python - cross_val_score 返回的分数与我的交叉验证分数的自定义实现之间的差异?
问题描述
我实现了我的自定义cross_val_score
功能。但结果与使用 sklearn 获得的结果不同cross_val_score
。
modelType = SGDClassifier(random_state=7)
cv2 = StratifiedKFold(5)
scores = cross_val_score(modelType, XTrainSc, yTrain, cv=cv2, scoring='accuracy', n_jobs=-1)
print(scores)
modelType = SGDClassifier(random_state=7)
ss=[]
for ti, vi in cv2.split(XTrainSc, yTrain):
print(str(len(ti))+" "+str(len(vi)))
model = clone(modelType)
model.fit(XTrainSc[ti], yTrain[ti])
preds = model.predict(XTrainSc[vi])
ss.append(np.mean(preds==yTrain[vi]))
print(ss)
这里scores
和ss
不相等。难道我做错了什么?
解决方案
StratifiedKfold
当它决定每个折叠的索引时,它也有随机性部分。因此,设置random_state
对于获得可重复性至关重要。
这是一个可重现的示例:
>>> from sklearn import datasets, linear_model
>>> from sklearn.model_selection import cross_val_score, StratifiedKFold
>>> from sklearn.base import clone
>>> import numpy as np
>>> X, y = datasets.load_breast_cancer(return_X_y=True)
model = linear_model.SGDClassifier(random_state=7)
cv2 = StratifiedKFold(5,random_state=0)
scores = cross_val_score(model, X, y, cv=cv2, scoring='accuracy', n_jobs=-1)
print(scores)
model = linear_model.SGDClassifier(random_state=7)
ss=[]
for ti, vi in cv2.split(X, y):
print(str(len(ti))+" "+str(len(vi)))
model = clone(model)
model.fit(X[ti], y[ti])
preds = model.predict(X[vi])
ss.append(np.mean(preds==y[vi]))
print(ss)
输出:
[0.91304348 0.70434783 0.45132743 0.38938053 0.38053097]
454 115
454 115
456 113
456 113
456 113
[0.9130434782608695, 0.7043478260869566, 0.45132743362831856, 0.3893805309734513, 0.3805309734513274]
推荐阅读
- datatable - 如何在 ComboBox SelectedIndexChanged 事件的文本框中获取 DataTable 列值?
- python-3.x - Holoviews - 选择抛出 AttributeError 的图
- xcode - Xcode Storyboards 现在只显示大纲
- highcharts - R Highcharter:通过单击突出显示/取消突出显示条形图中的条形
- python - 在 PyCharm 中运行带有 GPU 支持的 Tensorflow 内存不足
- csv - 如何将非常大的 PySpark 数据框导出为 CSV 文件?
- swift - 从 tableViewController 创建一个 segue 是在目标 viewController 中添加一个导航栏
- python - 当函数有附加参数时如何使用 scipy.optimize.bisect()?
- excel - Excel - 我们发现这个公式有问题。尝试单击公式选项卡上的插入功能
- c# - 如何通过单击按钮(asp.net)下载文件?