How do I avoid augmenting data in the validation split of the Keras ImageDataGenerator?

Problem description

I am using the following generator:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    fill_mode='nearest',
    cval=0,
    rescale=1. / 255,
    rotation_range=90,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.5,
    horizontal_flip=True,
    vertical_flip=True,
    validation_split=0.5,
)

train_generator = datagen.flow_from_dataframe(
    dataframe=traindf,
    directory=train_path,
    x_col="id",
    y_col=classes,
    subset="training",
    batch_size=8,
    seed=123,
    shuffle=True,
    class_mode="other",
    target_size=(64,64))


STEP_SIZE_TRAIN = train_generator.n // train_generator.batch_size

valid_generator = datagen.flow_from_dataframe(
    dataframe=traindf,
    directory=train_path,
    x_col="id",
    y_col=classes,
    subset="validation",
    batch_size=8,
    seed=123,
    shuffle=True,
    class_mode="raw",
    target_size=(64, 64))

STEP_SIZE_VALID = valid_generator.n // valid_generator.batch_size

The problem is that the validation data is also being augmented, which I believe is not what you want to do during training. How can I avoid this? I don't have separate directories for training and validation; I want to train the network from a single dataframe. Any suggestions?

Tags: python, tensorflow, keras, data-augmentation

Solution


The solution a friend of mine found is to use a separate generator for validation, with the same validation split and no shuffling. The key point is that validation_split always carves off the same deterministic slice of the dataframe, so two ImageDataGenerator objects created with the same split value address the same partition: the training generator applies the augmentation, while the validation generator only rescales.

# Generator for the training subset: augmentation plus rescaling.
datagen = ImageDataGenerator(
    #featurewise_center=True,
    #featurewise_std_normalization=True,
    rescale=1. / 255,
    rotation_range=90,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.5,
    horizontal_flip=True,
    vertical_flip=True,
    validation_split=0.15,
)

# Generator for the validation subset: only rescaling, same validation_split.
valid_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.15)

Then you can define the two generators as

train_generator = datagen.flow_from_dataframe(
    dataframe=traindf,
    directory=train_path,
    x_col="id",
    y_col=classes,
    subset="training",
    batch_size=64,
    seed=123,
    shuffle=False,
    class_mode="raw",
    target_size=(224,224))


STEP_SIZE_TRAIN = train_generator.n // train_generator.batch_size

valid_generator = valid_datagen.flow_from_dataframe(
    dataframe=traindf,
    directory=train_path,
    x_col="id",
    y_col=classes,
    subset="validation",
    batch_size=64,
    seed=123,
    shuffle=False,
    class_mode="raw",
    target_size=(224, 224))

STEP_SIZE_VALID = valid_generator.n // valid_generator.batch_size
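
To double-check that the two generators really address disjoint subsets of the same dataframe, you can compare the file lists they expose. This is a small sanity check of my own, not part of the original answer; it assumes the filenames attribute that flow_from_dataframe iterators provide.

# Sanity check (not in the original answer): with the same dataframe and the same
# validation_split, the "training" and "validation" subsets should not overlap.
train_files = set(train_generator.filenames)
valid_files = set(valid_generator.filenames)
print("overlapping files:", len(train_files & valid_files))  # expected to be 0
print("total files:", len(train_files) + len(valid_files))   # expected to equal len(traindf)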

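Finally, both generators can be passed to training as usual, so that only the non-augmented batches from valid_generator are used for the validation metrics. A minimal sketch, assuming TensorFlow 2.x and an already-compiled Keras model named model (the model itself is not shown in the original post):

# Train on augmented batches, validate on rescaled-only batches.
history = model.fit(
    train_generator,
    steps_per_epoch=STEP_SIZE_TRAIN,
    validation_data=valid_generator,
    validation_steps=STEP_SIZE_VALID,
    epochs=20,  # arbitrary example value
)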