首页 > 解决方案 > 如何从 ImageDataGenerator 获取 x_train?

问题描述

我正在研究图像分类 CNN,我想知道如何使用 ImageGenerator 获得 x_train 和 y_train 形式。这样做的原因是我想将我的模型拟合为 fit note fit.generator

trainDataGen = ImageDataGenerator(rescale= 1./255,
                                 rotation_range =30,
                                 width_shift_range=0.1,
                                 height_shift_range=0.1,
                                 shear_range=0.2,
                                 zoom_range=0.2,
                                 horizontal_flip=False,
                                 vertical_flip=False,
                                 fill_mode='nearest',                                
                                 )

trainGenSet = trainDataGen.flow_from_directory(
    path +'Train',
    target_size=(28,28),
    batch_size=32,
    class_mode='categorical',
    color_mode='grayscale',
)


x_train, y_train = trainGenSet.next()

由于批量大小为 32,print(x_train) 为 (32,28,28,1)。我总共有 5000 个训练数据集,我想为 print(x_train) 获取 (5000,28,28,1)

标签: deep-learningconv-neural-networkimage-classificationavassetimagegenerator

解决方案


推荐阅读