首页 > 解决方案 > sklean mean_squared_error 忽略平方参数,multioutput='raw_values'

问题描述

sklearn 的均方误差函数的文档页面提供了一些有关如何使用该函数的示例。包括如何将它用于多输出数据和计算 RMSE。问题是这在计算多个输出的 RMSE 时不起作用。

这是我使用的代码:

from sklearn.metrics import mean_squared_error

y_true = [[0.5, 1],[-1, 1],[7, -6]]
y_pred = [[0, 2],[-1, 2],[8, -5]]

mean_squared_error(y_true, y_pred)  # This returns the MSE
#out: 0.7083333333333334

mean_squared_error(y_true, y_pred, squared=False)  # And the RMSE works too
#out: 0.8416254115301732

mean_squared_error(y_true, y_pred, multioutput='raw_values')  # I can use the MSE for multiple outputs
#out: array([0.41666667, 1.        ])

mean_squared_error(y_true, y_pred, multioutput='raw_values', squared=False)  # But not the RMSE
#out: array([0.41666667, 1.        ])

# However
import numpy as np

np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values'))  # Numpy gives the correct results
#out: array([0.64549722, 1.        ])

一些规格:

Python 3.6.8 (default, Oct  7 2019, 12:59:55)
[GCC 8.3.0] on linux

sklearn.__version__
'0.22'

np.__version__
'1.17.4'

我查看了代码,但我不明白为什么这不起作用。

标签: pythonnumpyscikit-learn

解决方案


  • 这是一个已知的、现已关闭的问题sklearn 0.23.2,在本答案的当前版本中不会发生。

  • 这在 numpy 1.19.1 和 sklearn 0.23.2 中不可重现

  • mean_squared_error(y_true, y_pred, multioutput='raw_values', squared=False)np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values'))返回相同的值。

  • 分辨率是升级。

  • 如果升级不是一个选项:

    • 在线: https ://github.com/scikit-learn/scikit-learn/blob/b194674c4/sklearn/metrics/_regression.py#L258 替换以下:
    • return output_errorsreturn output_errors if squared else np.sqrt(output_errors)

推荐阅读