首页 > 解决方案 > 从tensorflow记录数据集中批量提取图像数据集

问题描述

我最近开始使用 tensorflow 研究 CNN,发现 tfrecords 对加快训练非常有帮助,但是我在数据 API 方面遇到了困难。
解析后,我的数据集由(图像,标签)元组组成,这对于训练来说很好,但是我试图在另一个数据集中提取图像以调用 keras.predict() 。

我试过这个解决方案:

test_set = get_set_tfrecord(test_path, _parse_function, num_parallel_calls = 4)

lab = []
f = True
for image, label in test_set.take(600):
    if f:
      img = tf.data.Dataset.from_tensors(image)
      f = False
    else:
      img = img.concatenate(tf.data.Dataset.from_tensors(image))
    lab.append(label.numpy())

天真,不是很好的代码,但它的工作原理是为了执行连接(即堆叠),它将每个图像加载到 RAM 中。

这样做的正确方法是什么?

标签: pythontensorflowmachine-learningkerascomputer-vision

解决方案


您可以使用map 来自. tf.data.Dataset您可以编写以下代码。

result = test_set.map(lambda image, label: image)
# You can iterate and check what you have received at the end.
# I expect only the images.
for image in result.take(1):
    print(image)

我希望使用上面的代码可以解决您的问题,并且这个答案可以很好地为您服务。


推荐阅读