python - PyTorch Tensor.index_select() 如何评估张量输出?
问题描述
我无法理解索引的复杂性 - 张量的非连续索引是如何工作的。这是一个示例代码及其输出
import torch
def describe(x):
print("Type: {}".format(x.type()))
print("Shape/size: {}".format(x.shape))
print("Values: \n{}".format(x))
indices = torch.LongTensor([0,2])
x = torch.arange(6).view(2,3)
describe(torch.index_select(x, dim=1, index=indices))
返回输出为
类型:torch.LongTensor 形状/大小:torch.Size([2, 2]) 值:tensor([[0, 2], [3, 5]])
有人可以解释它是如何到达这个输出张量的吗?谢谢!
解决方案
您正在从第一个轴 ( ) 上选择第一个 ( indices[0]
is 0
) 和第三个 ( indices[1]
is 2
) 张量。本质上,with的工作方式与使用 . 在第二个轴上进行直接索引相同。x
dim=0
torch.index_select
dim=1
x[:, indices]
>>> x
tensor([[0, 1, 2],
[3, 4, 5]])
所以选择列(因为你正在查看dim=1
而不是dim=0
)哪些索引在indices
. 想象一下有一个简单的列表 :[0, 2]
indices
>>> indices = [0, 2]
>>> x[:, indices[0]] # same as x[:, 0]
tensor([0, 3])
>>> x[:, indices[1]] # same as x[:, 2]
tensor([2, 5])
因此,将索引作为 a 传递torch.Tensor
允许您直接索引索引的所有元素,即列0
和2
. 类似于 NumPy 的索引工作方式。
>>> x[:, indices]
tensor([[0, 2],
[3, 5]])
这是另一个示例,可帮助您了解其工作原理。x
如此定义,x = torch.arange(9).view(3, 3)
我们有3行(aka dim=0
)和3列(aka dim=1
)。
>>> indices
tensor([0, 2]) # namely 'first' and 'third'
>>> x = torch.arange(9).view(3, 3)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> x.index_select(0, indices) # select first and third rows
tensor([[0, 1, 2],
[6, 7, 8]])
>>> x.index_select(1, indices) # select first and third columns
tensor([[0, 2],
[3, 5],
[6, 8]])
注意:torch.index_select(x, dim, indices)
相当于x.index_select(dim, indices)
推荐阅读
- amazon-dynamodb - 无服务器 dynamodb 启用连续备份
- sharepoint - 从 SharePoint 列表自动生成报告
- python - Pythonic 方法来制作一个可以接受可迭代或任意数量的参数的函数?
- java - 为什么我需要对对象数组中的特定引用进行强制转换?
- javascript - Dart DateTime.parse 删除 1 小时
- postgresql - Helm 中的 PostgreSQL:initdbScripts 参数
- python - AttributeError:“WebElement”对象没有属性(python)(硒)
- triggers - 有没有办法让崩溃的 Informix 触发器写出日志?
- typescript - 使用 VS Code 使 Debbuing Electron 应用程序运行良好
- python - 使用 Python 正确解析 PDF 段落