首页 > 解决方案 > 相当于torch.gather的张量流

问题描述

我有一个张量的形状(16, 4096, 3)。我有另一个形状指数张量(16, 32768, 3)。我正在尝试收集这些值dim=1。这最初是在 pytorch 中使用收集功能完成的,如下所示-

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

请注意,输出的大小bidx. 但是,当我应用gathertensorflow 的功能时,我得到了完全不同的输出。发现输出维度不匹配,如下所示 -

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

我也尝试使用tf.gather_nd但徒劳无功。见下文-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

为什么我会得到不同形状的张量?我想得到与 pytorch 计算的相同形状的张量。

换句话说,我想知道torch.gather的tensorflow等价物。

标签: pythontensorflowpytorch

解决方案


对于 2D 情况,有一种方法可以做到:

# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)

但是,对于 ND 情况,这种方法可能非常复杂


推荐阅读