tensorflow - 如何从生成器创建固定长度的 tf.Dataset?
问题描述
我有一个生成器可以产生无限量的数据(随机图像裁剪)。我想创建一个tf.Dataset
基于假设 10,000 个第一个数据点并将其缓存以使用它们来训练模型?
目前,我有一个生成器需要 1-2 秒来创建每个数据点,这是主要的性能障碍。我要等一分钟才能生成一批 64 张图像(这个preprocessing()
功能很昂贵,所以我想重用结果)。
ds = tf.Dataset.from_generator()
方法允许我们创建这样的无限数据集。相反,我想使用生成器的 N 个第一个输出创建一个有限数据集,并将其缓存起来,如下所示:
ds = ds.cache()
.
替代解决方案是继续生成新数据,并在渲染生成器时使用缓存的数据点。
解决方案
您可以将Dataset.cache
函数与Dataset.take
函数一起使用来完成此操作。
如果一切都适合内存,那么只需执行以下操作即可:
def generate_example():
i = 0
while(True):
print ('yielding value {}'.format(i))
yield tf.random.uniform((64,64,3))
i +=1
ds = tf.data.Dataset.from_generator(generate_example, tf.float32)
first_n_datapoints = ds.take(n).cache()
现在请注意,如果我设置n
为 3 说然后做一些微不足道的事情,比如:
for i in first_n_datapoints.repeat():
print ('')
print (i.shape)
然后我看到输出确认前 3 个值已被缓存(yielding value {i}
对于生成的前 3 个值中的每一个,我只看到一次输出:
yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...
如果所有内容都不适合内存,那么我们可以将文件路径传递给缓存函数,它将生成的张量缓存到磁盘。
更多信息在这里:https ://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache
推荐阅读
- python - 裁剪数组中的值的更有效方法?
- c# - MSBuild 为后台任务生成无效的 AppxManifest.xml 文件
- laravel - Laravel:auth.php 中的存储密钥是什么?
- spring - Spring Data REST 重命名资源
- excel - 列出对象表命名范围
- node.js - 如何在 NodeJS 中将 JSON 传递给服务器?
- java - 在 Spring 中使用 @QueryParam
- javascript - 从另一个页面返回时如何强制 HTML 重新加载
- shell - 如何将 perl 转换为 shell 脚本?
- python - Dash plotly 自定义地图