首页 > 解决方案 > 带有轴参数的Tensorflow tf.gather

问题描述

我正在使用 tensorflowtf.gather从多维数组中获取元素,如下所示:

import tensorflow as tf

indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

result = tf.gather(x, indices, axis=1)

with tf.Session() as sess:
    selection = sess.run(result)
    print(selection)

这导致:

[[1 2 2]
 [4 5 5]
 [7 8 8]]

我想要的是:

[1
 5
 8]

如何使用tf.gather在指定轴上应用单个索引?(与此答案中指定的解决方法相同的结果:https ://stackoverflow.com/a/41845855/9763766 )

标签: pythontensorflow

解决方案


您需要将 转换indicesfull indices,并使用gather_nd。可以通过这样做来实现:

result = tf.squeeze(tf.gather_nd(x,tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices[...,tf.newaxis]], axis=2)))

推荐阅读