首页 > 解决方案 > ndarray/张量索引

问题描述

我有一个形状为 (NB, N, 2, 2) 的张量 A。

如果我有一个列表 B,其中包含我想保留在张量 A 中的长度为 NB 的索引,我应该怎么做?也就是说,我想根据 B 中的索引,每批次保留 1 个(N 个)元素。

我可以通过一个 for 循环来完成它,指定 A 中的批次i和 b 中的第i个元素。但是有没有一种矢量化的方式来做到这一点?

我尝试了 A[B] 或 A[B.unsqueeze(1)],两者都有索引错误。并且 A[:, B] 将为每个批次返回 NB 个元素。

例子:

A = Tensor([[[a 2x2 mat AAA1], [a 2x2 mat BBB1], [a 2x2 mat CCC1], [a 2x2 mat DDD1]], 
    [[a 2x2 mat AAA2], [a 2x2 mat BBB2], [a 2x2 mat CCC2], [a 2x2 mat DDD2]], 
    [[a 2x2 mat AAA3], [a 2x2 mat BBB3], [a 2x2 mat CCC3], [a 2x2 mat DDD3]]
  ])

B = [1, 3, 0]

预期输出:

Tensor([[[a 2x2 mat BBB1]], 
    [[a 2x2 mat DDD2]], 
    [[a 2x2 mat AAA3]]
    ])

标签: pythonpytorchtensor

解决方案


torch.gather前来救援。

准备您的索引列表,例如

# A.shape = (NB, N, 2, 2)
B = torch.tensor([1, 3, 0]) # should be of length NB
B = B[:, None, None, None].repeat(1, # your actual indecies in batch dim
                                  1, # indexing dim to be kept 1
                                  2, # these two must be repeated
                                  2)

最后,gather像这样使用

torch.gather(A, 1, B) # indexing along '1'-th dim

推荐阅读