首页 > 解决方案 > 为什么我们在 train_test_split 的两个数组中都包含目标类?

问题描述

X_train, test_df, y_train, y_test = train_test_split(result, y_true, stratify = y_true, test_size = 0.2)

在上面使用 train_test_split 的示例中,result是数据帧,并且y_true是从数据帧的目标类列形成的 numpy 数组。

我的问题是,如果我们已经分别给出了“y_true”,为什么我们还要将整个“结果”数据框作为 train_test_split 中的输入参数之一?我的意思是,我们不应该首先从“结果”数据框中排除目标类列吗?

标签: machine-learningscikit-learntrain-test-split

解决方案


Scikit-learn 支持 pandas,但 pandas 不是必需的。对于 numpy 数组,将特征和标签放在同一个数组中并不总是有意义的,因此是train_test_split函数的当前设计。因此,您需要确保您的resultDataFrame 及其拆分具有您想要的格式。如果y_trueresultDataFrame 的一部分,您可以(并且应该)选择在函数调用之前或之后将其排除。


推荐阅读