python - Tensorflow 中复杂的切片操作
问题描述
我被 TensorFlow 上的切片操作困住了。我想做的是在 Numpy 中是这样的,
>>> a = np.arange(24).reshape((4,6))
>>> a
array([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
>>> print(a[[2,3],[0,1]])
array([12, 19])
然而在 TensorFlow 中,
>>> a = tf.Variable(np.arange(24).reshape((4,6)))
>>> with tf.Session() as sess:
... sess.run(tf.global_variables_initializer())
... print(sess.run(a[[2,3],[0,1]]))
我有一个错误说TypeError: can only concatenate list (not "int") to list
。有没有办法在 Tensorflow 中执行这种切片?
谢谢你。
解决方案
这是一种方式。但是我已经重新组织了索引([2,0],[3,1]
)。
a = tf.Variable(np.arange(24).reshape((4, 6)))
sess.run(tf.global_variables_initializer())
print(sess.run(tf.gather_nd(a, [[2,0],[3,1]])))
输出是
[12 19]