python-3.x - tf.map_fn 中的张量流索引
问题描述
tf.map_fn
当我使用索引变量引用张量流中的图像数组时,我遇到了“切片索引必须是整数”错误。如果我不使用索引变量并对值进行硬编码,结果就可以了。以下是详细信息:
print(img_x.shape)
xm = tf.map_fn(lambda i: img_x[i:i+end_range], tf.range(100), dtype=tf.float32)
输出:
(4800,)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-62-15b6c15c4e0f> in <module>()
28 img_x = np.squeeze(img_x)
29 print(img_x.shape)
---> 30 xm = tf.map_fn(lambda i: img_x[i:i+end_range], tf.range(100), dtype=tf.float32)
31 # sess2 = tf.Session()
32 # xm.eval(session=sess2)
~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/functional_ops.py in map_fn(fn, elems, dtype, parallel_iterations, back_prop, swap_memory, infer_shape, name)
411 parallel_iterations=parallel_iterations,
412 back_prop=back_prop,
--> 413 swap_memory=swap_memory)
414 results_flat = [r.stack() for r in r_a]
415
~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations)
3094 swap_memory=swap_memory)
3095 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
-> 3096 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
3097 if maximum_iterations is not None:
3098 return result[1]
~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants)
2872 self.Enter()
2873 original_body_result, exit_vars = self._BuildLoop(
-> 2874 pred, body, original_loop_vars, loop_vars, shape_invariants)
2875 finally:
2876 self.Exit()
~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants)
2812 flat_sequence=vars_for_body_with_tensor_arrays)
2813 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
-> 2814 body_result = body(*packed_vars_for_body)
2815 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
2816 if not nest.is_sequence(body_result):
~/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/ops/functional_ops.py in compute(i, tas)
401 """
402 packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
--> 403 packed_fn_values = fn(packed_values)
404 nest.assert_same_structure(dtype or elems, packed_fn_values)
405 flat_fn_values = output_flatten(packed_fn_values)
<ipython-input-62-15b6c15c4e0f> in <lambda>(i)
28 img_x = np.squeeze(img_x)
29 print(img_x.shape)
---> 30 xm = tf.map_fn(lambda i: img_x[i:i+end_range], tf.range(100), dtype=tf.float32)
31 # sess2 = tf.Session()
32 # xm.eval(session=sess2)
TypeError: slice indices must be integers or None or have an __index__ method
但是,如果我硬编码i变量,程序不会崩溃:
xm = tf.map_fn(lambda i: img_x[0:0+end_range], tf.range(100), dtype=tf.float32)
我会很感激任何帮助。谢谢。
编辑
似乎 img_x 是一个 numpy 数组,这就是它不起作用的原因。我使用将 img_x 转换为张量,tf.convert_to_tensor
并且效果很好。谢谢。
解决方案
推荐阅读
- javascript - body中选择对象的条件:JSON.stringify()
- c++ - 来自公共接口的兄弟姐妹的交叉引用
- python - 通过拆分另一个列表的索引在 Python 中创建一个列表可能吗?
- python - wagtail-generic-chooser 小部件 (NoReverseMatch)
- javascript - 如何使用javascript交换html div元素位置
- python - ValueError:无法将输入数组从形状(11,140,55)广播到形状(11,140,11)
- android - 每次询问时如何确保获得gps位置
- swift - Google Firebase:如何使用 Snapchat 对用户进行身份验证?
- c# - C# 设计 - 如何在没有空接口的情况下将类和枚举分组到列表中?
- python - Fbprophet 错误“系列”对象没有属性“非零”