首页 > 解决方案 > PyTorch 和 Numpy 中的张量入口选择逻辑发散

问题描述

描述

我正在设置一个torch.Tensor用于掩蔽目的。当尝试通过索引选择条目时,使用numpy.ndarraytorch.Tensor保存索引数据之间的行为是不同的。我希望能够访问解释差异的框架和相关文档中的设计。

复制步骤

环境

官方发布的容器中的 Pytorch 1.3:pytorch/pytorch:1.3-cuda10.1-cudnn7-devel

例子

假设我需要设置masktorch.Tensor具有形状的对象[3,3,3]并在条目(0,0,1)&处设置值(1,2,0)to 1。下面的代码解释了差异。

mask = torch.zeros([3,3,3])
indices = torch.tensor([[0, 1],
                        [0, 2],
                        [1, 0]])

mask[indices.numpy()] = 1 # Works
# mask[indices] = 1 # Incorrect result

我注意到,当使用mask[indices.numpy()]new torch.Tensorof shape[2]时,whilemask[indices]返回一个 new torch.Tensorof shape [3, 2, 3, 3],这表明张量切片逻辑存在差异。

标签: pythonnumpypytorch

解决方案


您会得到不同的结果,因为这就是在 Pytorch 中实现索引的方式。如果您将数组作为索引传递,那么它会被“解包”。例如:

indices = torch.tensor([[0, 1], [0, 2], [1, 0]])

mask = torch.arange(1,28).reshape(3,3,3)

# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [ 7,  8,  9]],

#         [[10, 11, 12],
#          [13, 14, 15],
#          [16, 17, 18]],

#         [[19, 20, 21],
#          [22, 23, 24],
#          [25, 26, 27]]])

mask[indices.numpy()]等价于,即第mask[[0, 1], [0, 2], [1, 0]]i 行的indices.numpy()元素用于选择mask沿第 i 轴的元素。所以它返回tensor([mask[0,0,1], mask[1,2,0]]),即tensor([2, 16])

另一方面,当将张量作为索引传递时(我不知道数组和张量用于索引的这种区别的确切原因),它不像数组那样“解包”,并且第 i 行的元素张量indices的 用于选择mask沿轴 0 的元素。也就是说,mask[indices]相当于mask[[[0, 1], [0, 2], [1, 0]], :, :]

>>> mask[ind]

tensor([[[[ 1,  2,  3],
          [ 4,  5,  6],
          [ 7,  8,  9]],

         [[10, 11, 12],
          [13, 14, 15],
          [16, 17, 18]]],


        [[[ 1,  2,  3],
          [ 4,  5,  6],
          [ 7,  8,  9]],

         [[19, 20, 21],
          [22, 23, 24],
          [25, 26, 27]]],


        [[[10, 11, 12],
          [13, 14, 15],
          [16, 17, 18]],

         [[ 1,  2,  3],
          [ 4,  5,  6],
          [ 7,  8,  9]]]])

这基本上是tensor(mask[[0,1], :, :], mask[[0,2],: ,:], mask[[1,0], :, :])有形状indices.shape + mask[0,:,:].shape == (3,2,3,3)的。因此,整个“工作表”被选择并堆叠成新的维度。请注意,这不是一个新的张量,而是 的一个特殊视图mask。因此,如果您指定mask[indices] = 1, 带有这个 special indices,那么 的所有元素都mask将变为 1。


推荐阅读