首页 > 解决方案 > 来自 keras 模型中图像列表的 TensorFlow 数据集

问题描述

我试图了解如何读取本地图像,将它们用作 TensorFlow数据集并使用 TF 数据集训练 Keras 模型。我正在关注 TF Keras MNIST TPU教程。我想阅读我的一组图像并对其进行训练的唯一区别。

假设我有图像列表(文件名)和相应的标签列表。

files = [...] # list of file names
labels = [...] # list of labels (integers)
images = tf.constant(files) # or tf.convert_to_tensor(files)
labels = tf.constant(labels) # or tf.convert_to_tensor(labels)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(len(files))
dataset = dataset.repeat()
dataset = dataset.map(parse_function).batch(batch_size)

parse_function是一个简单的函数,它读取输入文件名并产生图像数据和相应的标签,例如

def parse_function(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_image(image_string)
    image = tf.cast(image_decoded, tf.float32)
    return image, label

此时我有dataset一个 tf.data.Dataset 类型(更准确地说是 tf.data.BatchDataset),我将它trained_model教程传递给 keras 模型,例如

history = trained_model.fit(dataset, ...)

但此时代码因以下错误而中断:

AttributeError: 'BatchDataset' object has no attribute 'ndim'

该错误来自keras,它对给定的输入进行检查

from keras import backend as K
K.is_tensor(dataset) # which returns false

Keras 尝试确定输入的类型,并且由于它不是张量,因此它假定它是 numpy 数组并尝试获取其维度。这就是发生错误的原因。

我的问题如下:

任何建议将不胜感激。

请注意,即使 TF教程是关于 TPU 的,它的结构也可以在 TPU 和 CPU/GPU 上运行。

标签: pythontensorflowkeras

解决方案


原来问题在于使用 Keras 模型。TF 教程中的示例依赖于使用 tf.keras 模块构建的 Keras 模型(所有层、模型等都来自 tf.keras)。虽然我使用的模型(DenseNet)依赖于纯 keras 模块,即所有层都来自 keras 模块,而不是来自 tf.keras。这会导致 tf.data.Dataset 检查 ndim in fit keras 模型的方法。一旦我调整了我的 DenseNet 以使用 tf.keras 层,一切都会重新开始工作。


推荐阅读