python - one_shot_iterator,占位符,无法捕获占位符
问题描述
我尝试one_shot_iterator
从数据集中制作一个。
我使用占位符来使用更少的 GPU 内存,并希望我只需初始化迭代器一次。
但我得到错误:
Traceback (most recent call last):
File "test_placeholder.py", line 18, in <module>
it = dset.make_one_shot_iterator()
File "<...>/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 205, in make_one_shot_iterator
six.reraise(ValueError, err)
File "<...>/site-packages/six.py", line 692, in reraise
raise value.with_traceback(tb)
ValueError: Cannot capture a placeholder (name:Placeholder,
type:Placeholder) by value.
测试:
import tensorflow as tf
import numpy as np
buf_size = 50
batch_size = 10
n_rows = 117
a = np.random.choice(7, size=n_rows)
b = np.random.uniform(0, 1, size=(n_rows, 4))
a_ph = tf.placeholder(a.dtype, a.shape)
b_ph = tf.placeholder(b.dtype, b.shape)
with tf.Session() as sess:
dset = tf.data.Dataset.from_tensor_slices((a_ph, b_ph))
dset = dset.shuffle(buf_size).batch(batch_size).repeat()
feed_dict = {a_ph: a, b_ph: b}
it = dset.make_one_shot_iterator()
n_batches = len(a) // batch_size
sess.run(it.initializer, feed_dict=feed_dict)
for i in range(n_batches):
a_chunk, b_chunk = it.get_next()
print(a_chunk, b_chunk)
什么地方出了错?
谢谢。
解决方案
查看导入数据的指南
“一次性迭代器是最简单的迭代器形式,它只支持对数据集进行一次迭代,不需要显式初始化。一次性迭代器处理几乎所有现有基于队列的输入管道支持的情况,但是它们不支持参数化。”
这就是您的错误的原因,因为此特定迭代器不支持任何带有占位符的参数化。我们可以使用 make_initializable_iterator 代替。
这是您进行修改的代码以及您正在寻找的结果。
buf_size = 50
batch_size = 10
n_rows = 117
a = np.random.choice(7, size=n_rows)
b = np.random.uniform(0, 1, size=(n_rows, 4))
a_ph = tf.placeholder(a.dtype, a.shape)
b_ph = tf.placeholder(b.dtype, b.shape)
with tf.Session() as sess:
dset = tf.data.Dataset.from_tensor_slices((a_ph, b_ph))
dset = dset.shuffle(buf_size).batch(batch_size).repeat()
feed_dict = {a_ph: a, b_ph: b}
it = dset.make_initializable_iterator()
n_batches = len(a) // batch_size
sess.run(it.initializer, feed_dict=feed_dict)
for i in range(n_batches):
a_chunk, b_chunk = it.get_next()
print(a_chunk, b_chunk)
结果:
Tensor("IteratorGetNext:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_1:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_1:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_2:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_2:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_3:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_3:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_4:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_4:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_5:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_5:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_6:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_6:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_7:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_7:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_8:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_8:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_9:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_9:1", shape=(?, 4), dtype=float64)
Tensor("IteratorGetNext_10:0", shape=(?,), dtype=int32) Tensor("IteratorGetNext_10:1", shape=(?, 4), dtype=float64)
推荐阅读
- python - 使用 Pandas 从 excel 列中获取所有粗体字
- c++ - std::vectors 的未解决问题
- python - 在 Mac 上打开任何 Python 文件时出现“IDLE 意外退出”
- c# - 单击按钮时无法打开表单
- c# - 为什么方法组会导致堆分配?
- node.js - Nodejs Paypal 支付 SDK 格式错误的 JSON 错误
- node.js - Sequelize - 同步多对多
- c# - 解析从流加载的 libvlcsharp 视频不起作用
- c++ - 如何解决这两个错误?'strlwr' - '但参数 2 的类型为 'int''
- symfony - Symfony:如何在编辑功能上检查控制器中的多种字符串类型