首页 > 解决方案 > 使用 keras ImageDataGenerator flow_from_dataframe 时,验证集仅从一个类中获取图像

问题描述

我有一个图像列表以及它所属的这种格式的类:

列表.txt

image1 good
image2 good
image3 good
.
.
.
image4 bad
image5 bad
image6 bad

我使用 ImageDataGenerator 来拆分验证数据:

train_datagen = ImageDataGenerator(rescale=1./255, validation_split = 0.25)

我使用熊猫从文件中读取数据框:

load_images = pd.read_csv("list.txt", delim_whitespace = True, header = None)
load_images.columns = ['filename','class']
load_images.columns = load_images.columns.str.strip()

trainDataframe = load_images    

我使用 flow_from_dataframe 创建训练和验证生成器:

train_generator = train_datagen.flow_from_dataframe(
        trainDataFrame,
        x_col = 'filename',
        y_col = 'class',
        directory = path_to_parent_folder_of_images,
        target_size=(inputHeight, inputWidth),
        batch_size=batch_size,
        class_mode='categorical',
        subset = 'training',
        save_to_dir = "path_to_folder\\training",
        shuffle = True)

validation_generator = train_datagen.flow_from_dataframe(
        trainDataFrame,
        x_col = 'filename',
        y_col = 'class',
        directory = path_to_parent_folder_of_images,
        target_size=(inputHeight, inputWidth),
        batch_size=batch_size,
        class_mode='categorical',
        subset= 'validation',
        save_to_dir = "path_to_folder\\validation",
        shuffle = True)

最后我训练模型:

model.fit_generator(
    train_generator,
    steps_per_epoch = train_generator.n // train_generator.batch_size,
    epochs = epochs,
    validation_data = validation_generator,
    validation_steps = validation_generator.n // validation_generator.batch_size,
    callbacks = callback_list)        

问题是验证集只包含来自bad类的图像。没有其他班级的图像。我使用了将图像保存到目录参数,我只看到一个类的图像。训练生成器看起来不错(有好有坏的图像)。由于此错误,我的验证准确度始终为 0 或 1。我在网上看到了一些例子,并试图遵循它们。似乎没有人面临这个问题,所以我不确定我做错了什么。

我正在使用这些版本:python - 3.7.4

张量流 - 2.0.0

喀拉拉邦 - 2.3.1

标签: pythontensorflowkeras

解决方案


我意识到 flow_from_dataframe() 从列表中获取前 25% 的图像,而不是随机选择。由于我的列表是排序的,这意味着所有好的类都在一起,坏的一起,它会获取前 25% 的图像并将其发送到验证集,并且由于列表是排序的,它总是将好的图像放在 val_set 中。我用了

from sklearn.utils import shuffle dataframes = shuffle(dataframes)

洗牌并将其发送到 flow_from_dataframe() 并解决了问题。


推荐阅读