首页 > 解决方案 > PyTorch Tensor.index_select() 如何评估张量输出?

问题描述

我无法理解索引的复杂性 - 张量的非连续索引是如何工作的。这是一个示例代码及其输出

import torch

def describe(x):
  print("Type: {}".format(x.type()))
  print("Shape/size: {}".format(x.shape))
  print("Values: \n{}".format(x))


indices = torch.LongTensor([0,2])
x = torch.arange(6).view(2,3)
describe(torch.index_select(x, dim=1, index=indices))

返回输出为

类型:torch.LongTensor 形状/大小:torch.Size([2, 2]) 值:tensor([[0, 2], [3, 5]])

有人可以解释它是如何到达这个输出张量的吗?谢谢!

标签: pythonindexingpytorchtensor

解决方案


您正在从第一个轴 ( ) 上选择第一个 ( indices[0]is 0) 和第三个 ( indices[1]is 2) 张量。本质上,with的工作方式与使用 . 在第二个轴上进行直接索引相同。xdim=0torch.index_selectdim=1x[:, indices]

>>> x
tensor([[0, 1, 2],
        [3, 4, 5]])

所以选择列(因为你正在查看dim=1而不是dim=0)哪些索引在indices. 想象一下有一个简单的列表[0, 2]indices

>>> indices = [0, 2]

>>> x[:, indices[0]] # same as x[:, 0]
tensor([0, 3])

>>> x[:, indices[1]] # same as x[:, 2]
tensor([2, 5])

因此,将索引作为 a 传递torch.Tensor允许您直接索引索引的所有元素,即列02. 类似于 NumPy 的索引工作方式。

>>> x[:, indices]
tensor([[0, 2],
        [3, 5]])

这是另一个示例,可帮助您了解其工作原理。x如此定义,x = torch.arange(9).view(3, 3)我们有3(aka dim=0)和3列(aka dim=1)。

>>> indices
tensor([0, 2]) # namely 'first' and 'third'

>>> x = torch.arange(9).view(3, 3)
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> x.index_select(0, indices) # select first and third rows
tensor([[0, 1, 2],
        [6, 7, 8]])

>>> x.index_select(1, indices) # select first and third columns
tensor([[0, 2],
        [3, 5],
        [6, 8]])

注意torch.index_select(x, dim, indices)相当于x.index_select(dim, indices)


推荐阅读