首页 > 解决方案 > 来自 Keras 的 ImageDataGenerator 返回“TypeError:数据类型不理解”

问题描述

我使用 Keras ImageDataGenerator.flow_from_directory(...) 创建训练和测试数据集。然后我想用这些数据拟合model.fit()。在 Tensorflow 2.1 中它工作得非常好。但是,在 Tensorflow 2.2 中运行相同的代码会生成:TypeError: data type not understood. 您建议如何克服此问题并运行 TF2.2?

代码示例:

train_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255., dtype=tf.float32)
train_data = train_gen.flow_from_directory(directory=os.path.join(current_dir, data, 'train/'), target_size=(width, height), class_mode='sparse')

...

model.fit(train_data, epochs=50) # This generates an error in TF2.2, but in TF2.1 works fine.

在 TF2.2 中生成此错误的另一种方法是迭代生成器:

train_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255., dtype=tf.float32)
train_data = train_gen.flow_from_directory(directory=os.path.join(current_dir, data, 'train/'), target_size=(width, height), class_mode='sparse')

for x,y in train_data:
    print(type(x), type(y))

标签: pythontensorflowimage-processingkeras

解决方案


问题出在 keras 版本上。以下配置导致错误。

keras 2.3.1
keras-preprocessing 1.1.2

更改为此版本后,一切正常:

keras 2.4.3
keras-preprocessing 1.1.0

推荐阅读