scikit-learn - 当我使用“r2”作为评分时,sklearn cross_val_score() 返回 NaN 值
问题描述
我正在尝试使用 sklearn cross_val_score()。以下是我尝试过的示例:
# loocv evaluate random forest on the housing dataset
from numpy import mean
from numpy import std
from numpy import absolute
from pandas import read_csv
from sklearn.model_selection import LeaveOneOut
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestRegressor
# load dataset
url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/housing.csv'
dataframe = read_csv(url, header=None)
data = dataframe.values
# split into inputs and outputs
X, y = data[:, :-1], data[:, -1]
print(X.shape, y.shape)
# create loocv procedure
cv = LeaveOneOut()
# create model
model = RandomForestRegressor(random_state=1)
# evaluate model
scores = cross_val_score(model, X, y, scoring='neg_mean_absolute_error', cv=cv, n_jobs=-1)
# force positive
scores = absolute(scores)
# report performance
print('MAE: %.3f (%.3f)' % (mean(scores), std(scores)))
上面的代码可以正常工作,没有任何问题。但是,当我scoring
变成时r2
,里面的所有值scores
都会变成nan
。
解决方案
问题是与作为评分功能 LeaveOneOut()
结合使用。将以这样一种方式拆分数据,即仅一个样本用于测试,其余样本用于训练。当您使用以下公式计算验证集时,问题就来了:r2
LeaveOneOut()
r2
分母变为零,因为n=1
(只有一个样本要验证)所以y_bar = y_i
因为平均值等于你拥有的一个数字,这会导致nan
你观察到。如果您cv = No. of data points
如下所示,这势必会发生:
# evaluate model
scores = cross_val_score(model, X[0:10], y[0:10], scoring='r2', cv=10, n_jobs=-1)
# force positive
scores = absolute(scores)
# report performance
print('MAE: %.3f (%.3f)' % (mean(scores), std(scores)))
MAE: nan (nan)
现在,当我为其设置其他值时,n
它可以正常工作:
# evaluate model
scores = cross_val_score(model, X[0:10], y[0:10], scoring='r2', cv=3, n_jobs=-1)
# force positive
scores = absolute(scores)
# report performance
print('MAE: %.3f (%.3f)' % (mean(scores), std(scores)))
MAE: 0.662 (0.229)
推荐阅读
- asp.net-core - 我们可以在通过 c# 代码使用 html 发送传真时使用外部 css 链接吗?
- asp.net-mvc - IIS 子应用程序根
- c++ - 删除智能指针
- python - Django Rest Framework - serializers.SlugRelatedField 不返回字段值
- r - 嵌套在R中的循环中,用于相同的变量
- python - 从串口读取调理
- javascript - 使用 puppeter 在 Angular 中下载到 PDF 它正在保存空 pdf
- node.js - '(req: ProtectedRequest, res: Response
, next: NextFunction)' 不能分配给“AsyncFunction”类型的参数 - linux - 如何使用 grep 命令找到不包含字母“e”的 8 个字母单词的数量?
- sql - 如何在 SQL Server 中沿列获取 json 字符串