python - next() 总是给出与 KFold 生成器相同的索引
问题描述
我正在关注此线程以使用 sklean 的 KFold 生成用于交叉验证的 kfold 索引。
from sklearn.model_selection import KFold
import numpy as np
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([1, 2, 3, 4, 5])
当我使用 for 循环时,一切正常:
for train_index, test_index in kf.split(X):
print("TRAIN:", train_index, "TEST:", test_index)
给我:
TRAIN: [1 2 3 4] TEST: [0]
TRAIN: [0 2 3 4] TEST: [1]
TRAIN: [0 1 3 4] TEST: [2]
TRAIN: [0 1 2 4] TEST: [3]
TRAIN: [0 1 2 3] TEST: [4]
但是,当我使用 时next()
,无论我运行多少次,我总是得到相同的索引:
train_idx, test_idx = next(kf.split(X))
print(train_idx, test_idx)
[1 2 3 4] [0]
有什么我想念的吗?谢谢
解决方案
如评论中所述,您需要next()
调用split()
.
代码尝试:
from sklearn.model_selection import KFold
import numpy as np
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([1, 2, 3, 4, 5])
kf = KFold(n_splits=5)
randomIter = kf.split(X)
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
推荐阅读
- javascript - 如何检测旋转的矩形何时相互碰撞
- javascript - DataTables 子行 IE 和 Edge 混淆错误
- python - Python查找轮廓并绘制轮廓函数错误
- ubuntu-19.04 - Ubuntu 19.04 错误 404 Not Found [IP: 91.189.95.83 80] apt 更新错误
- python - Django Rest Framework:将destroy与pk以外的其他参数一起使用
- logic - 无法弄清楚为什么没有交易被执行
- javascript - html 文件不会将 onclick 属性读取为字符串
- vue.js - 创建新数据后如何更新 Strapi GraphQL 缓存?
- c++ - 无法使用安装程序下载和安装某些 MinGw 软件包
- echo - zsh:括号前保留斜线