首页 > 解决方案 > 使用表示索引的 1D 长张量选择 3D Pytorch Tensor 的特定索引

问题描述

所以我有一个 M x B x C 的张量,其中 M 是模型的数量,B 是批次,C 是类别,每个单元格是给定模型和批次的类别的概率。然后我有一个正确答案的张量,它只是大小为 B 的一维,我们称之为“t”。如何使用大小为 B 的 1D 只返回 M x B x 1,其中返回的张量只是正确类的值?假设我尝试过的 M x B x C 张量被称为“blah”

blah[:, :, C]

for i in range(M):
    blah[i, :, C]

blah[:, C, :]

前 2 只返回每个切片的第 3 维中索引 t 的值。最后一个返回第二维中 t 个索引处的值。我该怎么做呢?

标签: pythonpytorchtensor

解决方案


我们可以通过结合高级和基本索引来获得想要的结果

import torch

# shape [2, 3, 4]
blah = torch.tensor([
    [[ 0,  1,  2,  3],
     [ 4,  5,  6,  7],
     [ 8,  9, 10, 11]],
    [[12, 13, 14, 15],
     [16, 17, 18, 19],
     [20, 21, 22, 23]]])

# shape [3]
t = torch.tensor([2, 1, 0])
b = torch.arange(blah.shape[1]).type_as(t)

# shape [2, 3, 1]
result = blah[:, b, t].unsqueeze(-1)

这导致

>>> result
tensor([[[ 2],
         [ 5],
         [ 8]],
        [[14],
         [17],
         [20]]])

推荐阅读