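tensorflow - Keras ImageDataGenerator flow class indices for sparse and categorical class_mode
Question
The flow_* methods of Keras's ImageDataGenerator require stringified class indices for the categorical and sparse class_mode. My class labels look like ['0','1',...,'10','11',...], and the unfortunate result is that Keras indexes these labels in string (alphabetical) order.
For example: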
datagen = ImageDataGenerator(
    rotation_range=0,
    width_shift_range=0,
    height_shift_range=0,
    rescale=None,
    shear_range=0,
    zoom_range=0,
    horizontal_flip=False,
    preprocessing_function=preprocessor,
    fill_mode='nearest')
test_generator = datagen.flow_from_dataframe(
    dataframe=dfTest,
    directory=None,
    x_col="filePath",
    y_col="ycat",
    target_size=SIZE,
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    shuffle=False)
print(test_generator.class_indices)
gives:
{'0': 0,
'1': 1,
'10': 2,
'11': 3,
...,
'2': 12,
'20': 13,
'21': 14,
'22': 15,
'3': 16,
'4': 17,
'5': 18,
'6': 19,
'7': 20,
'8': 21,
'9': 22}
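The ordering above is what plain string sorting produces: strings are compared character by character, so '10' sorts before '2'. A minimal illustration in plain Python, assuming 23 labels '0' through '22' as in the output above (this only mimics the observed result and is not Keras's actual internal code):

```python
# Labels as strings, matching the '0' ... '22' labels in the output above.
labels = [str(i) for i in range(23)]

# Sorting strings compares them character by character.
string_order = sorted(labels)
class_indices = {label: index for index, label in enumerate(string_order)}

# Sorting by integer value gives the order the question is hoping for.
numeric_order = sorted(labels, key=int)

print(string_order[:4])
print(numeric_order[:4])
```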
Ideally, I would like to see:
{'0': 0,
'1': 1,
'2': 2,
...,
}
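I considered manually changing test_generator.class_indices, but I'm not sure that is safe, because by the time initialization has finished the generator has already precomputed the class labels for the dataset.
Is there a good solution that does not involve rewriting the flow_* methods?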
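Answer
You can use the classes parameter to set the order you need. At the moment, however, the classes argument works in flow_from_directory but not in flow_from_dataframe. There is a ticket open for this in Keras, and it should be fixed in an upcoming TensorFlow release.
flow_from_directory example: here class_indices will be {'dogs': 0, 'cats': 1} rather than {'cats': 0, 'dogs': 1}.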
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
import os
import numpy as np
import matplotlib.pyplot as plt
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats') # directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs') # directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats') # directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs') # directory with our validation dog pictures
num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val
batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
train_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our training data
validation_image_generator = ImageDataGenerator(rescale=1./255,brightness_range=[0.5,1.5]) # Generator for our validation data
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary',
                                                           classes=['dogs', 'cats'])
print(train_data_gen.class_indices)
Output -
Found 2000 images belonging to 2 classes.
{'dogs': 0, 'cats': 1}
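Not working in flow_from_dataframe: this appears to be a bug and, according to the ticket, should be fixed soon. In the meantime, you can work around it by replacing the labels with letters (for example 'a' for '0', 'b' for '1'), or by changing your code to use flow_from_directory, although either option is extra overhead.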
%tensorflow_version 2.x
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
def append_ext(fn):
    # The CSV stores image ids without an extension; the files on disk are PNGs.
    return fn + ".png"
traindf = pd.read_csv("/content/trainLabels.csv",dtype=str)
traindf["id"]=traindf["id"].apply(append_ext)
datagen=ImageDataGenerator(rescale=1./255.,validation_split=0.25)
train_generator = datagen.flow_from_dataframe(
    dataframe=traindf,
    directory="/content/train/",
    x_col="id",
    y_col="label",
    subset="training",
    batch_size=32,
    seed=42,
    shuffle=True,
    class_mode="categorical",
    target_size=(32, 32),
    classes=['truck', 'ship', 'horse', 'frog', 'dog', 'deer', 'cat', 'bird', 'automobile', 'airplane'])
print(train_generator.class_indices)
Output - the classes are still indexed in alphanumeric order, and the classes argument is ignored.
Found 37500 validated image filenames belonging to 10 classes.
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
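As a variation on the label-replacement workaround for numeric labels, the strings can be zero-padded so that string order and numeric order coincide. This is a minimal sketch, not part of the original answer: `pad_labels` is a hypothetical helper, and the result only equals the integer label when the labels are the contiguous integers 0 to N-1. Single letters, by comparison, run out after 26 classes.

```python
def pad_labels(labels):
    """Zero-pad numeric string labels to a common width (hypothetical helper)."""
    width = max(len(label) for label in labels)
    return [label.zfill(width) for label in labels]


# Assumed labels '0' ... '22', as in the question.
labels = [str(i) for i in range(23)]
padded = pad_labels(labels)  # '00', '01', ..., '22'

# Mimic the alphabetical indexing seen in the generator output:
# sort the distinct label strings and number them in that order.
padded_class_indices = {
    name: index for index, name in enumerate(sorted(set(padded)))
}

print(padded_class_indices['02'], padded_class_indices['10'])
```

With a dataframe, the same padding would be applied to the label column (the one passed as y_col) before calling flow_from_dataframe, for example with pandas' `Series.str.zfill`.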
Hope this answers your question. Happy learning.