scikit-learn: Why does this 2-fold cross-validation figure look like 4-fold cross-validation?

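Problem description

The documentation uses this figure to illustrate 2-fold cross-validation:

[Figure: cross-validation diagram from the scikit-learn documentation]

The test set clearly takes up 1/4 of the data in the figure, even though the code uses n_splits=2: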

>>> import numpy as np
>>> from sklearn.model_selection import KFold

>>> X = ["a", "b", "c", "d"]
>>> kf = KFold(n_splits=2)
>>> for train, test in kf.split(X):
...     print("%s %s" % (train, test))
[2 3] [0 1]
[0 1] [2 3]

Why does the figure look like 4-fold cross-validation? Is it simply the wrong figure for this example?

Tags: python, scikit-learn

Solution


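You are right: the figure shows 4-fold cross-validation, where each test set is 1/4 of the data. Your code snippet, by contrast, produces two splits, so each test set is half of the data.

Your snippet behaves the same way as the example in the KFold documentation: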

>>> import numpy as np
>>> from sklearn.model_selection import KFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([1, 2, 3, 4])
>>> kf = KFold(n_splits=2)
>>> kf.get_n_splits(X)
2
>>> print(kf)  
KFold(n_splits=2, random_state=None, shuffle=False)
>>> for train_index, test_index in kf.split(X):
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
TRAIN: [2 3] TEST: [0 1]
TRAIN: [0 1] TEST: [2 3]
