python - 在 2.0 的会话中迭代 tf.data.Dataset 的正确方法
问题描述
我从youtube-8m 项目*.tfrecord
下载了一些数据。您可以使用以下命令下载数据的“小”部分:
curl data.yt8m.org/download.py | shard=1,100 partition=2/video/train mirror=us python
我试图了解如何使用新的 tf.data API。我想熟悉人们遍历数据集的典型方式。我一直在使用 TF 网站上的指南和这张幻灯片:Derek Murray 的幻灯片
这是我定义数据集的方式:
# Use interleave() and prefetch() to read many files concurrently.
files = tf.data.Dataset.list_files("./youtube_vids/*.tfrecord")
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100),
cycle_length=8)
# Use num_parallel_calls to parallelize map().
dataset = dataset.map(lambda record: tf.parse_single_example(record, feature_map),
num_parallel_calls=2) #
# put in x,y output form
dataset = dataset.map(lambda x: (x['mean_rgb'], x['id']))
# shuffle
dataset = dataset.shuffle(10000)
#one epoch
dataset = dataset.repeat(1)
dataset = dataset.batch(200)
#Use prefetch() to overlap the producer and consumer.
dataset = dataset.prefetch(10)
现在,我知道在急切执行模式下我可以
for x,y in dataset:
x,y
但是,当我尝试按如下方式创建迭代器时:
# A one-shot iterator automatically initializes itself on first use.
iterator = dset.make_one_shot_iterator()
# The return value of get_next() matches the dataset element type.
images, labels = iterator.get_next()
并与会话一起运行
with tf.Session() as sess:
# Loop until all elements have been consumed.
try:
while True:
r = sess.run(images)
except tf.errors.OutOfRangeError:
pass
我收到警告
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.
所以,这是我的问题:
在会话中迭代数据集的正确方法是什么?只是v1和v2差异的问题吗?
此外,将数据集直接传递给估计器的建议意味着输入函数也有一个迭代器,如上面 Derek Murray 的幻灯片中定义的那样,对吗?
解决方案
As for Estimator API, no you don't have to specify iterator, just pass dataset object as input function.
def input_fn(filename):
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.shuffle().repeat()
dataset = dataset.map(parse_func)
dataset = dataset.batch()
return dataset
estimator.train(input_fn=lambda: input_fn())
In TF 2.0 dataset became iterable, so, just as warning message says, you can use
for x,y in dataset:
x,y
推荐阅读
- scala - Scala猫库验证-收集所有无效对象
- c# - 当前上下文中不存在“InitializeComponent”和“_FrameViews”
- html - 图像在 flexbox 中完美调整大小,直到我添加另一个元素
- azure - 有什么办法可以缓存 git 源文件?
- css - 如何在不丢失父级高度的情况下在 div 中居中项目?
- javascript - 使用 javascript 将类样式添加到 CKEditor 文本
- c++ - 解释未解决的外部 C++
- python - 如何旋转 seaborn 热图?
- reactjs - React 中的 onMouseEnter 事件有问题
- javascript - 当我保存存储的数据时,只有一个数据是另一个数据变为空。使用foreach怎么样?