python - 来自 keras 模型中图像列表的 TensorFlow 数据集
问题描述
我试图了解如何读取本地图像,将它们用作 TensorFlow数据集并使用 TF 数据集训练 Keras 模型。我正在关注 TF Keras MNIST TPU教程。我想阅读我的一组图像并对其进行训练的唯一区别。
假设我有图像列表(文件名)和相应的标签列表。
files = [...] # list of file names
labels = [...] # list of labels (integers)
images = tf.constant(files) # or tf.convert_to_tensor(files)
labels = tf.constant(labels) # or tf.convert_to_tensor(labels)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(len(files))
dataset = dataset.repeat()
dataset = dataset.map(parse_function).batch(batch_size)
这parse_function
是一个简单的函数,它读取输入文件名并产生图像数据和相应的标签,例如
def parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string)
image = tf.cast(image_decoded, tf.float32)
return image, label
此时我有dataset
一个 tf.data.Dataset 类型(更准确地说是 tf.data.BatchDataset),我将它trained_model
从教程传递给 keras 模型,例如
history = trained_model.fit(dataset, ...)
但此时代码因以下错误而中断:
AttributeError: 'BatchDataset' object has no attribute 'ndim'
该错误来自keras,它对给定的输入进行检查
from keras import backend as K
K.is_tensor(dataset) # which returns false
Keras 尝试确定输入的类型,并且由于它不是张量,因此它假定它是 numpy 数组并尝试获取其维度。这就是发生错误的原因。
我的问题如下:
- 我是否正确阅读了 TF 数据集?我在互联网上查找了很多示例,似乎我正在按照人们的建议阅读
- 为什么我的数据集不是张量?可能是我需要执行额外的转换,但不是 TF教程的情况
- 为什么在 TF教程中一切都适用于 tf 数据集,我真的看不出他们读取 MNIST 数据的方式(数据格式不同,但最终他们得到图像)和我在这里做的事情有什么不同。
任何建议将不胜感激。
请注意,即使 TF教程是关于 TPU 的,它的结构也可以在 TPU 和 CPU/GPU 上运行。
解决方案
原来问题在于使用 Keras 模型。TF 教程中的示例依赖于使用 tf.keras 模块构建的 Keras 模型(所有层、模型等都来自 tf.keras)。虽然我使用的模型(DenseNet)依赖于纯 keras 模块,即所有层都来自 keras 模块,而不是来自 tf.keras。这会导致 tf.data.Dataset 检查 ndim in fit keras 模型的方法。一旦我调整了我的 DenseNet 以使用 tf.keras 层,一切都会重新开始工作。
推荐阅读
- windows - 如何在 installshield 中为 windows 的西装安装程序添加两个 msi 包
- vba - 在 VBA 中从网站下载所有带有前缀的文件
- python - Python -> Boost Python + C++ 错误
- java - 想要根据映射到 json 的键隐藏对象的某些字段
- python - 如何在 matplotlib 中复制 gnuplot 的伪 3D 图,以便随时间分配蛋白质的二级结构?
- python - 如何使用 pandas DataFrame 在列轴连接中使用 join_axes?
- php - 如何使用 laravel 5.4 制作工作菜单选项卡
- django - 如何在 Django Rest Framework 中注册后自动生成电子邮件?
- php - 从商店页面 woocommerce 添加到购物车分组产品
- x86-64 - 将 x86-64 上的 long double 传递给可变参数函数