Keras ImageDataGenerator flow class indices for sparse and categorical class modes

Problem description

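The flow_* methods of Keras's ImageDataGenerator require stringified class labels when class_mode is categorical or sparse. My class labels look like ['0','1',...,'10','11',...], and the unfortunate result is that Keras indexes these labels in lexicographic (string) order rather than numeric order: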

For example:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
        rotation_range=0,
        width_shift_range=0,
        height_shift_range=0,
        rescale=None,
        shear_range=0,
        zoom_range=0,
        horizontal_flip=False,
        preprocessing_function=preprocessor,
        fill_mode='nearest')

test_generator = datagen.flow_from_dataframe(
    dataframe=dfTest,
    directory=None,
    x_col="filePath",
    y_col="ycat",
    target_size=SIZE,
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    shuffle=False)

print(test_generator.class_indices)

This gives:

{'0': 0,
 '1': 1,
 '10': 2,
 '11': 3,
 ...,
 '2': 12,
 '20': 13,
 '21': 14,
 '22': 15,
 '3': 16,
 '4': 17,
 '5': 18,
 '6': 19,
 '7': 20,
 '8': 21,
 '9': 22}
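The ordering above matches what plain string sorting produces. The following minimal sketch reproduces it without Keras, using a shortened, hypothetical label set ('0' to '11'), and contrasts it with a numeric sort:

```python
# Labels as strings, as in the question (a shortened, hypothetical set).
labels = [str(i) for i in range(12)]

# Plain string sorting compares character by character, so '10' sorts before '2'.
lexicographic = sorted(labels)
print(dict(zip(lexicographic, range(len(lexicographic)))))

# Sorting with a numeric key gives the order the question is after.
numeric = sorted(labels, key=int)
print(dict(zip(numeric, range(len(numeric)))))
```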

Ideally, I would like to see:

{'0': 0,
 '1': 1,
 '2': 2,
 ...}

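I considered changing test_generator.class_indices by hand, but I am not sure that is safe, because the generator has already precomputed the class labels for the dataset by the time it is initialized.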
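One way to sidestep that concern is to leave the generator untouched and translate its indices back to the original labels afterwards. This is only a sketch: the `class_indices` dictionary and the `generator_indices` list below are hypothetical stand-ins for `test_generator.class_indices` and for whatever indices the generator or model produces (for example its precomputed labels or the argmax of predictions).

```python
# Hypothetical stand-in for test_generator.class_indices (label string -> index).
class_indices = {'0': 0, '1': 1, '10': 2, '11': 3, '2': 4, '3': 5}

# Invert the mapping: generator index -> original integer label.
index_to_label = {index: int(label) for label, index in class_indices.items()}

# Hypothetical stand-in for indices coming out of the generator or the model.
generator_indices = [2, 4, 0, 3, 5]

# Translate each generator index back to the label it represents.
true_labels = [index_to_label[i] for i in generator_indices]
print(true_labels)  # [10, 2, 0, 11, 3]
```

Because nothing inside the generator is modified, its precomputed labels and its class_indices stay consistent with each other.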

Is there a good solution for the flow_* methods that does not require overriding them?

Tags: tensorflow, keras

Solution


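You can use the classes argument to set the ordering you need. At the moment, however, the classes argument works in flow_from_directory but not in flow_from_dataframe. There is a ticket about this in Keras, and it should be fixed in an upcoming TensorFlow release.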

flow_from_directory example: here class_indices will be {'dogs': 0, 'cats': 1} instead of {'cats': 0, 'dogs': 1}.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam

import os
import numpy as np
import matplotlib.pyplot as plt

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'

path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)

PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

train_cats_dir = os.path.join(train_dir, 'cats')  # directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs')  # directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats')  # directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')  # directory with our validation dog pictures

num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))

num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))

total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val

batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150

train_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our training data
validation_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our validation data

train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary',
                                                           classes = ['dogs','cats'])

print(train_data_gen.class_indices)

Output:

Found 2000 images belonging to 2 classes.
{'dogs': 0, 'cats': 1}
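For numeric folder names like those in the question, the list passed as classes could be built by sorting the subfolder names by their integer value. The sketch below assumes a hypothetical layout with one subfolder per class ('0' to '11'), created in a temporary directory purely for illustration:

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout: one subfolder per numeric class, '0' .. '11'.
    for i in range(12):
        os.mkdir(os.path.join(root, str(i)))

    # Collect the subfolder names and order them by numeric value.
    classes = sorted(
        (name for name in os.listdir(root)
         if os.path.isdir(os.path.join(root, name))),
        key=int,
    )

print(classes)
```

The resulting list could then be passed as classes=classes to flow_from_directory, in the same way as ['dogs','cats'] in the example above.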

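Not working in flow_from_dataframe: this appears to be a bug and, according to the ticket, should be fixed soon. In the meantime, you can work around it by replacing the labels with letters (for example '0' becomes 'a', '1' becomes 'b'), or by changing your code to use flow_from_directory, although that adds overhead.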

# Colab-only magic that selects TensorFlow 2.x; remove it outside Colab.
%tensorflow_version 2.x
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
import pandas as pd

import os
import numpy as np
import matplotlib.pyplot as plt

def append_ext(fn):
    return fn+".png"

traindf = pd.read_csv("/content/trainLabels.csv",dtype=str)

traindf["id"]=traindf["id"].apply(append_ext)

datagen=ImageDataGenerator(rescale=1./255.,validation_split=0.25)

train_generator=datagen.flow_from_dataframe(
                            dataframe=traindf,
                            directory="/content/train/",
                            x_col="id",
                            y_col="label",
                            subset="training",
                            batch_size=32,
                            seed=42,
                            shuffle=True,
                            class_mode="categorical",
                            target_size=(32,32),
                            classes=['truck', 'ship', 'horse', 'frog', 'dog', 'deer', 'cat', 'bird', 'automobile', 'airplane'])

print(train_generator.class_indices)

Output (the classes are still ordered alphanumerically):

Found 37500 validated image filenames belonging to 10 classes.
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
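As a variation on the label-replacement workaround mentioned above, numeric labels could be zero-padded instead of mapped to letters, so that string order and numeric order agree. This is a different technique from the letter mapping, and the `zero_pad` helper and sample labels below are hypothetical:

```python
def zero_pad(labels):
    """Left-pad numeric label strings with zeros to a common width."""
    width = max(len(label) for label in labels)
    return [label.zfill(width) for label in labels]


labels = ['0', '1', '2', '10', '11', '22']  # hypothetical sample
padded = zero_pad(labels)
print(padded)  # ['00', '01', '02', '10', '11', '22']

# With equal-width labels, lexicographic and numeric ordering coincide.
print(sorted(padded) == sorted(padded, key=int))  # True
```

With a pandas label column, the same padding could be applied with the `.str.zfill(width)` string method before the dataframe is handed to flow_from_dataframe. Unlike single letters, this is not limited to 26 classes.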

Hope this answers your question. Happy learning.

