python - sklean mean_squared_error 忽略平方参数,multioutput='raw_values'
问题描述
sklearn 的均方误差函数的文档页面提供了一些有关如何使用该函数的示例。包括如何将它用于多输出数据和计算 RMSE。问题是这在计算多个输出的 RMSE 时不起作用。
这是我使用的代码:
from sklearn.metrics import mean_squared_error
y_true = [[0.5, 1],[-1, 1],[7, -6]]
y_pred = [[0, 2],[-1, 2],[8, -5]]
mean_squared_error(y_true, y_pred) # This returns the MSE
#out: 0.7083333333333334
mean_squared_error(y_true, y_pred, squared=False) # And the RMSE works too
#out: 0.8416254115301732
mean_squared_error(y_true, y_pred, multioutput='raw_values') # I can use the MSE for multiple outputs
#out: array([0.41666667, 1. ])
mean_squared_error(y_true, y_pred, multioutput='raw_values', squared=False) # But not the RMSE
#out: array([0.41666667, 1. ])
# However
import numpy as np
np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values')) # Numpy gives the correct results
#out: array([0.64549722, 1. ])
一些规格:
Python 3.6.8 (default, Oct 7 2019, 12:59:55)
[GCC 8.3.0] on linux
sklearn.__version__
'0.22'
np.__version__
'1.17.4'
我查看了源代码,但我不明白为什么这不起作用。
解决方案
这是一个已知的、现已关闭的问题
sklearn 0.23.2
,在本答案的当前版本中不会发生。这在 numpy 1.19.1 和 sklearn 0.23.2 中不可重现
mean_squared_error(y_true, y_pred, multioutput='raw_values', squared=False)
并np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values'))
返回相同的值。分辨率是升级。
如果升级不是一个选项:
- 在线: https ://github.com/scikit-learn/scikit-learn/blob/b194674c4/sklearn/metrics/_regression.py#L258 替换以下:
return output_errors
→return output_errors if squared else np.sqrt(output_errors)
推荐阅读
- r - 使用 IBrokers 和 R,检索多只股票的实时交易数据的合适方法是什么?
- python - apache Beam 广播一个 spacy 模型作为 Dataflow 中的侧面输入
- sql - 如何查询架构中没有成员的团队
- reactjs - 在 React 中使用 Swiper 断点时发生功能故障
- php - 强制形式主题
- javascript - Editor.JS SyntaxError:不能在模块外使用导入语句
- java - 如何使用 RestAssured Java 参数化 XML
- regex - 正则表达式只允许小数和 0 中的负数
- python - 更改python中的分配顺序时“超出时间限制”
- go - 两个 goroutine 之间的数据竞争