python - 拆分 k 折,其中每折验证数据不包含重复项
问题描述
假设我有一个 pandas dataframe df
。df
包含 1,000 行。如下所示。
print(df)
id class
0 0000799a2b2c42d 0
1 00042890562ff68 0
2 0005364cdcb8e5b 0
3 0007a5a46901c56 0
4 0009283e145448e 0
... ... ...
995 04309a8361c5a9e 0
996 0430bde854b470e 0
997 0431c56b712b9a5 1
998 043580af9803e8c 0
999 043733a88bfde0c 0
它有 950 个数据 asclass 0
和 50 个数据 as class 1
。
现在我想再添加一列fold
,如下所示。
id class fold
0 0000799a2b2c42d 0 0
1 00042890562ff68 0 0
2 0005364cdcb8e5b 0 0
3 0007a5a46901c56 0 0
4 0009283e145448e 0 0
... ... ... ...
995 04309a8361c5a9e 0 4
996 0430bde854b470e 0 4
997 0431c56b712b9a5 1 4
998 043580af9803e8c 0 4
999 043733a88bfde0c 0 4
其中fold
列包含 5 个折叠(0,1,2,3,4)。每个折叠有 200 个数据,其中 190 个数据作为class 0
,10 个数据作为class 1
(这意味着保留每个 的样本百分比class
)。
我已经尝试过StratifiedShuffleSplit
,sklearn.model_selection
如下所示。
sss = StratifiedShuffleSplit(n_split=5, random_state=2021, test_size = 0.2)
for _, val_index in sss.split(df.id, df.class):
....
然后我将每个列表val_index
视为一个特定的折叠,但它最终给了我每个val_index
.
有人能帮我吗?
解决方案
您需要的是用于交叉验证的 kfold,而不是训练测试拆分。您可以使用StratifiedKFold
,例如您的数据集是这样的:
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
np.random.seed(12345)
df = pd.DataFrame({'id' : np.random.randint(1,1e5,1000),
'class' :np.random.binomial(1,0.1,1000)})
df['fold'] = np.NaN
我们使用 kfold,像您一样遍历并分配折叠编号:
skf = StratifiedKFold(n_splits=5,shuffle=True)
for fold, [train,test] in enumerate(skf.split(df,df['class'])):
df.loc[test,"fold"] = fold
最终产品:
pd.crosstab(df['fold'],df['class'])
class 0 1
fold
0.0 182 18
1.0 182 18
2.0 182 18
3.0 182 18
4.0 181 19
推荐阅读
- javascript - Javascript正则表达式解析复杂的url字符串
- python - 提交时手动呈现 Django formset 重定向问题
- postgresql - Postgresql pg_profile 在创建快照时出错
- javascript - 轮播 css3 动画表现怪异
- python - 使用 TF 2.0 将 saved_model 转换为 TFLite 模型
- c++ - 使用 IMFSourceReader 进行音频流式传输(Microsoft Media Foundation)
- mongodb - MongoDB 聚合:将 mongodb 中的字段从 ObjectId 重命名为文字/字符串?
- angular - 有没有办法在 html 部分设置/绑定到组件变量
- facebook-graph-api - 如何使用 facebook 洞察图 API 获取在 fb 页面中喜欢帖子的人的个人资料名称/个人资料 ID
- istio - 这个例子中“VirtualService”的目的是什么?