TensorFlow indexing in tf.map_fn

Problem description

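I get a "slice indices must be integers" error when I use the index variable inside tf.map_fn to slice an image array in TensorFlow. If I hard-code the value instead of using the index variable, it works fine. Here are the details: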

print(img_x.shape)
xm = tf.map_fn(lambda i: img_x[i:i+end_range], tf.range(100), dtype=tf.float32)

Output:

(4800,)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-62-15b6c15c4e0f> in <module>()
     28 img_x = np.squeeze(img_x)
     29 print(img_x.shape)
---> 30 xm = tf.map_fn(lambda i: img_x[i:i+end_range], tf.range(100), dtype=tf.float32)
     31 # sess2 = tf.Session()
     32 # xm.eval(session=sess2)

~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/functional_ops.py in map_fn(fn, elems, dtype, parallel_iterations, back_prop, swap_memory, infer_shape, name)
    411         parallel_iterations=parallel_iterations,
    412         back_prop=back_prop,
--> 413         swap_memory=swap_memory)
    414     results_flat = [r.stack() for r in r_a]
    415 

~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations)
   3094         swap_memory=swap_memory)
   3095     ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
-> 3096     result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
   3097     if maximum_iterations is not None:
   3098       return result[1]

~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants)
   2872       self.Enter()
   2873       original_body_result, exit_vars = self._BuildLoop(
-> 2874           pred, body, original_loop_vars, loop_vars, shape_invariants)
   2875     finally:
   2876       self.Exit()

~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants)
   2812         flat_sequence=vars_for_body_with_tensor_arrays)
   2813     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
-> 2814     body_result = body(*packed_vars_for_body)
   2815     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   2816     if not nest.is_sequence(body_result):

~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/functional_ops.py in compute(i, tas)
    401       """
    402       packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
--> 403       packed_fn_values = fn(packed_values)
    404       nest.assert_same_structure(dtype or elems, packed_fn_values)
    405       flat_fn_values = output_flatten(packed_fn_values)

<ipython-input-62-15b6c15c4e0f> in <lambda>(i)
     28 img_x = np.squeeze(img_x)
     29 print(img_x.shape)
---> 30 xm = tf.map_fn(lambda i: img_x[i:i+end_range], tf.range(100), dtype=tf.float32)
     31 # sess2 = tf.Session()
     32 # xm.eval(session=sess2)

TypeError: slice indices must be integers or None or have an __index__ method

However, if I hard-code the value of i, the program does not crash:

xm = tf.map_fn(lambda i: img_x[0:0+end_range], tf.range(100), dtype=tf.float32)
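The difference between the two versions can be reproduced without TensorFlow. The sketch below is a minimal, hypothetical illustration: a Python list stands in for the NumPy array img_x (NumPy applies the same rule to slice bounds), SymbolicIndex is an invented class playing the role of the graph-mode tensor i, which has no concrete value while the graph is being built and so offers no __index__ method, and ConcreteIndex is an invented class that does provide __index__, like the plain integer 0 in the hard-coded version.

```python
class SymbolicIndex:
    """Hypothetical stand-in for a graph-mode tensor: no concrete value, no __index__."""

    def __add__(self, other):
        # Arithmetic on a symbolic value only produces another symbolic value.
        return SymbolicIndex()


class ConcreteIndex:
    """Hypothetical stand-in for a value that can be turned into a Python int."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return ConcreteIndex(self.value + other)

    def __index__(self):
        # Python's slicing machinery calls this to get an integer bound.
        return self.value


data = list(range(4800))  # hypothetical data, plays the role of img_x
end_range = 10            # hypothetical window length


def take_window(i):
    """Return data[i:i + end_range], or the TypeError message if slicing rejects i."""
    try:
        return data[i:i + end_range]
    except TypeError as exc:
        return f"TypeError: {exc}"


print(take_window(0))                 # plain int, like the hard-coded version
print(take_window(ConcreteIndex(3)))  # any object with __index__ is accepted
print(take_window(SymbolicIndex()))   # no __index__, so slicing raises TypeError
```

The last call fails for the same reason as the original lambda: the slice bounds are objects that cannot be converted to integers at the moment the slice is taken.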

Any help would be appreciated. Thanks.

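EDIT: It turns out img_x is a NumPy array, which is why it didn't work: NumPy slicing needs integer indices (or objects with an __index__ method), but the i that tf.map_fn passes to the lambda is a tensor. I converted img_x to a tensor with tf.convert_to_tensor, and now it works fine. Thanks.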
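The fix from the edit can be sketched as follows. This assumes TensorFlow 2.3 or later (for fn_output_signature, which replaces the deprecated dtype argument) and uses @tf.function so the loop runs in graph mode, as it did in the question's TF 1.x setup. The array contents and end_range = 10 are made-up stand-ins for the question's data. Once img_x is a tensor, img_x[i:i + end_range] is handled by TensorFlow's own slicing, which accepts tensor indices.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the question's variables.
end_range = 10
img_np = np.arange(4800, dtype=np.float32)  # shape (4800,), like img_x after np.squeeze

# The fix: convert the NumPy array to a tensor so tensor indices are allowed.
img_x = tf.convert_to_tensor(img_np)


@tf.function  # trace into a graph, so i is a symbolic tensor as in TF 1.x
def windows():
    # Each element i of tf.range(100) yields the window img_x[i:i + end_range].
    return tf.map_fn(
        lambda i: img_x[i:i + end_range],
        tf.range(100),
        fn_output_signature=tf.float32,
    )


xm = windows()
print(xm.shape)
```

Each row of xm is the window of length end_range starting at index i, so with these stand-in values xm has 100 rows of 10 elements.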

Tags: python-3.x, numpy, tensorflow
