首页 > 解决方案 > Tensorflow 中复杂的切片操作

问题描述

我被 TensorFlow 上的切片操作困住了。我想做的是在 Numpy 中是这样的,

>>> a = np.arange(24).reshape((4,6))
>>> a
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])
>>> print(a[[2,3],[0,1]])
array([12, 19])

然而在 TensorFlow 中,

>>> a = tf.Variable(np.arange(24).reshape((4,6)))
>>> with tf.Session() as sess:
...  sess.run(tf.global_variables_initializer())
...  print(sess.run(a[[2,3],[0,1]]))

我有一个错误说TypeError: can only concatenate list (not "int") to list。有没有办法在 Tensorflow 中执行这种切片?

谢谢你。

标签: pythontensorflow

解决方案


这是一种方式。但是我已经重新组织了索引([2,0],[3,1])。

a = tf.Variable(np.arange(24).reshape((4, 6)))

sess.run(tf.global_variables_initializer())

print(sess.run(tf.gather_nd(a, [[2,0],[3,1]])))

输出是

[12 19]


推荐阅读