next() always returns the same indices from a KFold generator

Problem description

I am following this thread to generate k-fold indices for cross-validation with sklearn's KFold.

from sklearn.model_selection import KFold
import numpy as np

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([1, 2, 3, 4, 5])

kf = KFold(n_splits=5)

When I use a for loop, everything works fine:

for train_index, test_index in kf.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)

This gives me:

TRAIN: [1 2 3 4] TEST: [0]
TRAIN: [0 2 3 4] TEST: [1]
TRAIN: [0 1 3 4] TEST: [2]
TRAIN: [0 1 2 4] TEST: [3]
TRAIN: [0 1 2 3] TEST: [4]

However, when I use next(), I always get the same indices no matter how many times I run it:

train_idx, test_idx = next(kf.split(X))
print(train_idx, test_idx)

[1 2 3 4] [0]

Is there something I'm missing? Thanks.

Tags: python, scikit-learn

Solution


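As mentioned in the comments, every call to kf.split(X) creates a new generator that starts again from the first fold, so next(kf.split(X)) always returns the first split. Call split() once, keep the generator it returns, and call next() on that same object.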

Try this code:

from sklearn.model_selection import KFold
import numpy as np

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([1, 2, 3, 4, 5])

kf = KFold(n_splits=5)

# Call split() once and keep the generator it returns.
randomIter = kf.split(X)

# Each next() call on the same generator advances to the following fold.
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
train_idx, test_idx = next(randomIter)
print(train_idx, test_idx)
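
As a minimal sketch of the same idea, the snippet below contrasts next() on a fresh split() call with two alternatives that may be more convenient: collecting all folds into a list, and passing a default to next() so an exhausted generator does not raise StopIteration. The names first_a, first_b, folds, split_iter and exhausted are illustrative and do not come from the original question.

```python
from sklearn.model_selection import KFold
import numpy as np

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
kf = KFold(n_splits=5)

# Each call to split() builds a brand-new generator, so next() on it
# always yields the first fold.
first_a = next(kf.split(X))
first_b = next(kf.split(X))

# Materialize all folds once when indexing into them is convenient.
folds = list(kf.split(X))
for train_idx, test_idx in folds:
    print(train_idx, test_idx)

# Consume a single generator fold by fold; a default value for next()
# avoids StopIteration once no folds are left.
split_iter = kf.split(X)
for _ in range(kf.get_n_splits()):
    next(split_iter)
exhausted = next(split_iter, None)
```

With n_splits=5, a sixth plain next(split_iter) call would raise StopIteration, which is why the default argument or a for loop is usually the safer choice.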
