首页 > 解决方案 > Pytorch:根据索引张量从 3d 张量中选择列

问题描述

我有一个 3DM维度张量[BxLxD]和一个 1Didx维度张量,[B,1]其中包含 range 中的列索引(0, L-1)。我想创建一个二维张量N[BxD]这样N[i,j] = M[i, idx[i], j]. 如何有效地做到这一点?

例子:

B,L,D = 2,4,2

M = torch.rand(B,L,D)

>

tensor([[[0.0612, 0.7385],
         [0.7675, 0.3444],
         [0.9129, 0.7601],
         [0.0567, 0.5602]],

        [[0.5450, 0.3749],
         [0.4212, 0.9243],
         [0.1965, 0.9654],
         [0.7230, 0.6295]]])


idx = torch.randint(0, L, size = (B,))

>

tensor([3, 0])

N = get_N(M, idx)

Expected output:

>

tensor([[0.0567, 0.5602], 
       [0.5450, 0.3749]])

谢谢。

标签: pytorchtensor

解决方案


import torch

B,L,D = 2,4,2

def get_N(M, idx):
    return M[torch.arange(B), idx, :].squeeze()

M = torch.tensor([[[0.0612, 0.7385],
                   [0.7675, 0.3444],
                   [0.9129, 0.7601],
                   [0.0567, 0.5602]],

                   [[0.5450, 0.3749],
                   [0.4212, 0.9243],
                   [0.1965, 0.9654],
                   [0.7230, 0.6295]]])
idx = torch.tensor([3,0])
N = get_N(M, idx)
print(N)

结果:

tensor([[0.0567, 0.5602],
        [0.5450, 0.3749]])

沿二维切片。


推荐阅读