首页 > 解决方案 > 如何从生成器创建固定长度的 tf.Dataset?

问题描述

我有一个生成器可以产生无限量的数据(随机图像裁剪)。我想创建一个tf.Dataset基于假设 10,000 个第一个数据点并将其缓存以使用它们来训练模型?

目前,我有一个生成器需要 1-2 秒来创建每个数据点,这是主要的性能障碍。我要等一分钟才能生成一批 64 张图像(这个preprocessing()功能很昂贵,所以我想重用结果)。

ds = tf.Dataset.from_generator()方法允许我们创建这样的无限数据集。相反,我想使用生成器的 N 个第一个输出创建一个有限数据集,并将其缓存起来,如下所示:

ds = ds.cache().


替代解决方案是继续生成新数据,并在渲染生成器时使用缓存的数据点。

标签: tensorflowtensorflow-datasets

解决方案


您可以将Dataset.cache函数与Dataset.take函数一起使用来完成此操作。

如果一切都适合内存,那么只需执行以下操作即可:

def generate_example():
  i = 0
  while(True):
    print ('yielding value {}'.format(i))
    yield tf.random.uniform((64,64,3))
    i +=1

ds = tf.data.Dataset.from_generator(generate_example, tf.float32)

first_n_datapoints = ds.take(n).cache()

现在请注意,如果我设置n为 3 说然后做一些微不足道的事情,比如:

for i in first_n_datapoints.repeat():
  print ('')
  print (i.shape)

然后我看到输出确认前 3 个值已被缓存(yielding value {i}对于生成的前 3 个值中的每一个,我只看到一次输出:

yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...

如果所有内容都不适合内存,那么我们可以将文件路径传递给缓存函数,它将生成的张量缓存到磁盘。

更多信息在这里:https ://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache


推荐阅读