tensorflow - data.make_initializable_iterator() 抛出错误:TypeFetch 参数必须是字符串或张量
问题描述
我正在尝试在 tensorflow 1.0 中为图像分类问题编写一个简单的数据生成器。我有一个图像路径列表和相应的标签作为 2 个列表:路径和标签。
我正在使用以下代码来获取数据对象和迭代器。
dataset = (
tf.data.Dataset.from_tensor_slices((paths, labels))
.shuffle(buffer_size = len(paths))
.map(parse_fn, num_parallel_calls = 4)
.batch(32)
.prefetch(1)
)
train_iter = dataset.make_initializable_iterator()
train_next = train_iter.get_next()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(train_iter)
x, y = sess.run(train_next)
但我收到以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-14-139e601c664d> in <module>()
25 with tf.Session() as sess:
26 sess.run(tf.global_variables_initializer())
---> 27 sess.run(train_iter)
28 x, y = sess.run(train_next)
29 print(x.shape, y.shape)
/home/surya/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
927 try:
928 result = self._run(None, fetches, feed_dict, options_ptr,
--> 929 run_metadata_ptr)
930 if run_metadata:
931 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/home/surya/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
1135 # Create a fetch handler to take care of the structure of fetches.
1136 fetch_handler = _FetchHandler(
-> 1137 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1138
1139 # Run request and get response.
/home/surya/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, graph, fetches, feeds, feed_handles)
469 """
470 with graph.as_default():
--> 471 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
472 self._fetches = []
473 self._targets = []
/home/surya/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
269 if isinstance(fetch, tensor_type):
270 fetches, contraction_fn = fetch_fn(fetch)
--> 271 return _ElementFetchMapper(fetches, contraction_fn)
272 # Did not find anything.
273 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
/home/surya/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, fetches, contraction_fn)
302 raise TypeError('Fetch argument %r has invalid type %r, '
303 'must be a string or Tensor. (%s)' %
--> 304 (fetch, type(fetch), str(e)))
305 except ValueError as e:
306 raise ValueError('Fetch argument %r cannot be interpreted as a '
TypeError: Fetch argument <tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fe5326ebf90> has invalid type <class 'tensorflow.python.data.ops.iterator_ops.Iterator'>, must be a string or Tensor. (Can not convert a Iterator into a Tensor or Operation.)
如果我将迭代器更改为
data_iter = dataset.make_one_shot_iterator()
为什么我会收到此错误,以及如何解决?谢谢!
解决方案
只需更改行:
sess.run(train_iter)
到:
sess.run(train_iter.initializer)
这是因为您要执行迭代器的初始化程序而不是迭代器本身。
推荐阅读
- c - 如何使用 Arduino 在 LCD 上无延迟()方法闪烁单个字符?
- php - (PHP) 为什么这是以下回显的输出?
- c# - 缺少 MSBuildCommunityTasks 文件夹
- dataprovider - 在 Genexus 中将 SDT 数据转换为 BC
- c++ - #define 一个宏和模板,它将运行任何类型的函数,包括 void,并返回平均运行时间?
- c++ - 我是否需要在头文件和源文件中指定调用约定
- spring-kafka - 从 Spring Boot 应用程序调用 ConsumConsumerSeekCallback 的 seek
- python - json 文件在使用 python 放入 zip 存档时损坏
- swift - CBCentralManager 状态最初注册为 .unknown
- arrays - 对数组中的奇数进行排序,同时保持偶数不变