python - kfold交叉验证后如何绘制每个折叠的数据和模型拟合?
问题描述
我试图根据一个特征预测一个标签变量。两者似乎是高度线性相关的。我选择了一个线性回归模型来描述数据。我的代码输出显示了训练和测试数据的 R2 分数。我的模型表现良好,预计测试样本的两倍,其中 R2 为负数。我想绘制每个折叠的数据和模型的拟合,以了解出了什么问题。但是,从 python 编码的角度来看,我无法弄清楚如何做到这一点。
任何人都可以帮忙吗?
Test_scores = list()
Train_scores =list()
n_splits = 5
kfold = KFold(n_splits=n_splits
, shuffle=False)
for train_ix, test_ix in kfold.split(Feature_X):
Train_Feature_X, Test_Feature_X=Feature_X[train_ix], Feature_X[test_ix]
Train_label_X, Test_label_X= label_X[train_ix],label_X[test_ix]
model = LinearRegression()
model.fit(Train_Feature_X, Train_label_X)
pred_label_train = model.predict(Train_Feature_X)
acc_train = r2_score(Train_label_X, pred_label_train)
pred_label_test = model.predict(Test_Feature_X)
acc_test = r2_score(Test_label_X, pred_label_test)
Test_scores.append(acc_test)
Train_scores.append(acc_train)
print('> ', 'Train:'+ str(acc_train), "Test:"+ str(acc_test))
Test_mean, Test_std = np.mean(Test_scores), np.std(Test_scores)
Train_mean, Train_std = np.mean(Train_scores), np.std(Train_scores)
print('Mean of test: %.3f, Standard Deviation: %.3f' % (Test_mean, Test_std))
print('Mean of train: %.3f, Standard Deviation: %.3f' % (Train_mean, Train_std))
代码输出:
> Train:0.9948113361306588 Test:0.9715872368615199
> Train:0.9905854864161807 Test:0.9917503220348162
> Train:0.9888929852977923 Test:-4.996610921978263
> Train:0.990942242777374 Test:0.9960355777732937
> Train:0.9925744355834707 Test:0.9458246438971184
Mean of test: -0.218, Standard Deviation: 2.389
Mean of train: 0.992, Standard Deviation: 0.002
解决方案
您可以将绘图添加到循环周期中。
每次迭代您都可以访问训练测试折叠和预测,因此在打印值之前,print('> ', 'Train:'+ str(acc_train), "Test:"+ str(acc_test))
您可以执行以下操作:
fig, ax = plt.subplots(nrows=1, ncols=5)
curr_split = 1
for ...
plt.subplot(1, 5, curr_split)
plt.plot(x, y)
curr_split += 1
plt.show()
这将绘制 5 个子图,每个子图代表折叠。
请注意,这是您应该做的一般概述,请参阅以下链接中的文档plt.subplots()
推荐阅读
- json - Spring Boot Ajax 解析错误 - 无法将对象返回到 ajax 成功 fn
- linux - 查找进程启动时间“/proc/pid”创建时间或“ps ef”命令哪个更可靠
- python - 删除列索引pandas python
- replace - 使用 Autohotkey 替换剪贴板中的变音符号
- git - 获取标签之前/中的所有提交以使用 Git cmd 或 GitHub API
- javascript - 无法访问 React Redux Form 组件中的值
- python - 如果元素不是无,则在列表中添加元素的优雅方式
- python - Python 在 true 时将值增加 100
- python - 从 Python sys.exit 字符串在 gnu make 中分配变量
- python-3.x - 图表上的简单椭圆(pyplot、k-means Coursera 课程)