首页 > 解决方案 > Sklearn:仅从每个目标类中获取少量记录

问题描述

我有一个具有多类分类(3 个类)的大型数据集,我想获取数据的子样本,即获取属于每个类的 200 条记录,然后根据该数据,我想拆分数据。

假设 3 个类是cat, dog, cow. 我想对数据子集应用拆分,其中从每个类的大型数据集中选择 200 条记录,cat以训练 ML 模型。dogcow

这是拆分数据的代码行:

# split data
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify = y, test_size = 0.3, 
                                                          random_state = 42)

我怎样才能选择Xy这样每个班级都有 200 条记录?

标签: pythonpandasdataframescikit-learn

解决方案


您可以groupby在“类”列中,然后您有几个选择:

  1. 如果您想随机选择 200 个,请使用sample聚合。

    df.groupby('class').sample(200, random_state=42)
    
  2. 如果不需要改组,则只需要每个的前 200 个,使用head聚合。

    df.groupby('class').head(200)
    

推荐阅读