How to split machine learning data by class?

Problem description

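I'm trying to split (X_train, Y_train) according to the different classes in Y_train. X_train consists of 50,000 images of 25 x 25 pixels, and Y_train consists of 50,000 binary labels (0 or 1). I tried to separate the data with the following code: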

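def split(X_train, Y_train):
    # Collect the images whose label is 0 (identifiers can't start with a digit)
    zero_only = []
    for x, y in zip(X_train, Y_train):
        if y == 0:
            zero_only.append(x)
    return zero_only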

Tags: python, dataset

Solution


This may do what you want:

import numpy as np

# Find the indices of the samples in Y_train that are zero
idx_zero = np.where(Y_train == 0)[0]

# Get subset of X_train and Y_train where Y_train is zero
X_train_zero = X_train[idx_zero]
Y_train_zero = Y_train[idx_zero]
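A boolean mask is a common alternative to the index array returned by np.where. The sketch below uses small hypothetical arrays in place of the real 50,000-image data, and assumes the labels might be stored as a column of shape (N, 1). That is why it calls ravel() first: a (N, 1) mask does not match the shape of a (N, 25, 25) array, and NumPy would raise an IndexError.

```python
import numpy as np

# Hypothetical stand-in data: 6 images of 25 x 25 pixels, labels stored
# as a column vector of shape (6, 1) to show why ravel() is needed.
rng = np.random.default_rng(0)
X_train = rng.random((6, 25, 25))
Y_train = np.array([0, 1, 1, 0, 1, 0]).reshape(-1, 1)

# Boolean mask of shape (6,); ravel() makes the labels 1-D so the mask
# lines up with the first axis of X_train.
mask_zero = Y_train.ravel() == 0
X_train_zero = X_train[mask_zero]
Y_train_zero = Y_train[mask_zero]

# Check that the mask selects the same rows as the np.where-based version.
same = np.array_equal(X_train_zero, X_train[np.where(Y_train == 0)[0]])
print(X_train_zero.shape, same)  # (3, 25, 25) True
```

Both approaches select the same samples; the mask version simply skips the intermediate index array.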

You can then do the same for class 1 using np.where(Y_train == 1)[0].
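To split into every class in one pass (which also works if there are more than two labels), one option is to loop over np.unique and build a dictionary keyed by label. This is a minimal sketch; the function name split_by_class and the tiny example arrays are hypothetical.

```python
import numpy as np

def split_by_class(X, Y):
    """Return a dict mapping each label in Y to the matching samples of X."""
    labels = np.asarray(Y).ravel()  # accept labels of shape (N,) or (N, 1)
    # np.unique returns the distinct labels in sorted order
    return {int(c): X[labels == c] for c in np.unique(labels)}

# Hypothetical example: 5 small 25 x 25 "images" with labels 0/1.
X = np.arange(5 * 25 * 25).reshape(5, 25, 25)
Y = np.array([1, 0, 0, 1, 0])
parts = split_by_class(X, Y)
print({c: a.shape for c, a in parts.items()})  # {0: (3, 25, 25), 1: (2, 25, 25)}
```

With the question's data, split_by_class(X_train, Y_train)[0] would hold the class-0 images and [1] the class-1 images.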

