python - 如何根据pytorch中的另一个张量选择索引
问题描述
这项任务似乎很简单,但我不知道该怎么做。
所以我有两个张量:
indices
具有 shape的索引张量(2, 5, 2)
,其中最后一个维度对应于 x 和 y 维度中的索引value
一个带有 shape的“值张量”(2, 5, 2, 16, 16)
,我希望用 x 和 y 索引选择最后两个维度
更具体地说,索引在 0 到 15 之间,我想得到一个输出:
out = value[:, :, :, x_indices, y_indices]
因此,输出的形状应该是(2, 5, 2)
。有人可以在这里帮助我吗?非常感谢!
编辑:
我尝试了用收集的建议,但不幸的是它似乎不起作用(我改变了尺寸,但没关系):
首先我生成一个坐标网格:
y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)
在下一步中,我将创建一些索引。在这种情况下,我总是取索引 1:
indices = torch.ones([1, 3, 2], dtype=torch.int64)
接下来,我正在使用您的方法:
indices = indices.unsqueeze(-1).unsqueeze(-1)
new_coords = torch.gather(grid, -1, indices).squeeze(-1).squeeze(-1)
最后,我手动为 x 和 y 坐标选择索引 1:
new_coords_manual = grid[:, :, :, 1, 1]
这将输出以下新坐标:
new_coords
tensor([[[-1.0000, -0.8667],
[-1.0000, -0.8667],
[-1.0000, -0.8667]]])
new_coords_manual
tensor([[[-0.8667, -0.8667],
[-0.8667, -0.8667],
[-0.8667, -0.8667]]])
如您所见,它仅适用于一维。你知道如何解决这个问题吗?
解决方案
您可以做的是将前三个轴展平并应用torch.gather
:
>>> grid.flatten(start_dim=0, end_dim=2).shape
torch.Size([6, 16, 16])
>>> torch.gather(grid.flatten(0, 2), axis=1, indices)
tensor([[[-0.8667, -0.8667],
[-0.8667, -0.8667],
[-0.8667, -0.8667]]])
如文档页面所述,这将执行:
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
推荐阅读
- android - onAdFailedtoLoad: 3 错误仅适用于 AdMob 的 Android 版 Flutter 应用
- amazon-redshift - 在 Redshift 中重新创建表时访问被拒绝
- aws-sdk - 使用 micronaut-bom1.2.10、graal19.2.1 和 aws-sdk2.10.56 构建应用程序时出现错误:org.apache.commons.logging.LogFactoryjava.lang.NoClassDefFoundError
- amazon-web-services - 是否可以使用 .Net Core 2.2 在 AWS Lambda 上使用 Kinesis FireHose 执行 PutRecord?
- python - 更新 Google 表格而不覆盖现有的 google 表格
- reactjs - 渲染选项中的条件 [材质表]
- ansible - 如果启用了 ufw,如何运行 ansible 任务?
- python - 不使用小写或大写的不区分大小写的字符串比较
- jquery - 我如何捕捉菜单的第 8 个元素以在新选项卡中打开
- javascript - 如何小写对象中键的第一个字母?