python-3.x - How does test_size relate when used in python sklearn for a 10 fold cross validation
Problem description
I am trying to implement an ML algorithm using a 10-fold cross-validation process, and I would like to confirm whether my procedure is correct.
I am doing binary classification and have about 50 samples of each class in each of the 10 folders that I created, called fold 1, fold 2, and so on.
My sklearn command is:
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(X, yy, test_size=0.3, random_state=1000)
Am I totally wrong here and this procedure is actually just doing a 30% test and 70% train process? For 10-fold cross-validation, should I instead be using:
from sklearn.model_selection import KFold
kf = KFold(n_splits=2, random_state=42, shuffle=True)
Thanks!
Solution
Am I totally wrong here and this procedure is actually just doing a 30% test and 70% train process?
Yes, setting test_size=0.3 gives you a 30% test size and a 70% train size. We know this from reading the documentation:
test_size: float or int, default=None
If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split.
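You can also confirm this empirically by splitting a small dataset and counting the rows. This is a sketch only: the 100-sample `X` and balanced `yy` below are made-up stand-ins for the question's data; `train_test_split`, `test_size=0.3` and `random_state=1000` are taken from the question.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 100 samples with 5 features and balanced binary labels
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
yy = np.array([0, 1] * 50)

x_train, x_test, y_train, y_test = train_test_split(
    X, yy, test_size=0.3, random_state=1000
)

# A float test_size is a proportion: 30% of the 100 rows go to the test split
print(len(x_train), len(x_test))  # 70 30
```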
If you're repeating this 10 times with a different random_state each time, then some elements will appear in more than one of the 10 test sets. The purpose of k-fold cross-validation is to create k disjoint sets and use each set in turn as a holdout. Your procedure is not cross-validation, because the test sets it produces can never all be disjoint: 10 test sets of 30% each would need 300% of the data, so by the pigeonhole principle at least one sample must land in two of them.
kf = KFold(n_splits=2, random_state=42, shuffle=True)
This isn't 10-fold CV, because n_splits=2. We know this from reading the documentation: the argument n_splits should be the number of folds. You've said you want 10 folds, so use n_splits=10.
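A minimal sketch of an actual 10-fold CV follows. Assumptions: the 100-sample `X` and `yy` are synthetic stand-ins, and `LogisticRegression` is just a placeholder classifier, not something from the question. It shows that with `n_splits=10` every sample appears in exactly one test fold, and how `cross_val_score` runs one fit per fold.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

# Hypothetical data standing in for the question's X and yy
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
yy = np.array([0, 1] * 50)

kf = KFold(n_splits=10, shuffle=True, random_state=42)

# Collect the test indices of every fold; together they cover each sample once
test_folds = [test_idx for _, test_idx in kf.split(X)]
all_test = np.concatenate(test_folds)
print(len(test_folds), len(all_test), len(np.unique(all_test)))  # 10 100 100

# Fit and score the placeholder model once per fold (10 scores in total)
scores = cross_val_score(LogisticRegression(), X, yy, cv=kf)
print(scores.mean())
```

For a classification problem like this one, `StratifiedKFold(n_splits=10, shuffle=True, random_state=42)` can be passed as `cv` instead, so that each fold keeps roughly the same class proportions. Passing a plain integer such as `cv=10` to `cross_val_score` with a classifier also uses stratified folds, but without shuffling.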