首页 > 解决方案 > 如何从张量元组创建 TF 数据集?(和最佳实践)

问题描述

我在创建自定义数据集时遇到了一些问题。现在我设法获得了一个生成器功能:

def dataset_generator():
for path in pathlist_img:
    img, labels = process_path(path)
    yield img, labels 

它返回两个tensorflow张量:第一个shape=(720, 1280, 3), dtype=uint8, 第二个shape=(?, 14), dtype=float32, where "?" 意味着它取决于图像(它是一个对象检测数据集,因此识别的实例数量不固定)。

我想要一个数据集,其中我的图像与我的标签相关联,所以这就是我产生元组的原因。

问题是我的数据集

dataset = tf.data.Dataset.from_generator(dataset_generator, ((tf.uint8, tf.uint8, tf.uint8),(tf.float32, tf.float32)))

这只是一个

<FlatMapDataset shapes: ((<unknown>, <unknown>, <unknown>), (<unknown>, <unknown>)), types: ((tf.uint8, tf.uint8, tf.uint8), (tf.float32, tf.float32))>

并且似乎没有包含任何东西,或者至少我无法从中得到任何东西。

是否有一些从图像和标签文件构建数据集的最佳实践?我该如何解决这个问题?

标签: pythontensorflowtensorflow2.0tensorflow-datasets

解决方案


尝试像这样重写:

img = tf.ones((730, 1280, 3), dtype=tf.uint8)
label = tf.ones((tf.random.uniform(shape=[], minval=1, maxval=10), 14),
                dtype=tf.float32)
def gen():
    yield img,label

dataset = tf.data.Dataset.from_generator(
      gen,
      (tf.uint8, tf.float32),
)

在输出类型中,您应该只指定整体张量的整体 dtype,而不是每个维度的 dtype。


推荐阅读