pytorch - Pytorch:根据索引张量从 3d 张量中选择列
问题描述
我有一个 3DM
维度张量[BxLxD]
和一个 1Didx
维度张量,[B,1]
其中包含 range 中的列索引(0, L-1)
。我想创建一个二维张量N
,[BxD]
这样N[i,j] = M[i, idx[i], j]
. 如何有效地做到这一点?
例子:
B,L,D = 2,4,2
M = torch.rand(B,L,D)
>
tensor([[[0.0612, 0.7385],
[0.7675, 0.3444],
[0.9129, 0.7601],
[0.0567, 0.5602]],
[[0.5450, 0.3749],
[0.4212, 0.9243],
[0.1965, 0.9654],
[0.7230, 0.6295]]])
idx = torch.randint(0, L, size = (B,))
>
tensor([3, 0])
N = get_N(M, idx)
Expected output:
>
tensor([[0.0567, 0.5602],
[0.5450, 0.3749]])
谢谢。
解决方案
import torch
B,L,D = 2,4,2
def get_N(M, idx):
return M[torch.arange(B), idx, :].squeeze()
M = torch.tensor([[[0.0612, 0.7385],
[0.7675, 0.3444],
[0.9129, 0.7601],
[0.0567, 0.5602]],
[[0.5450, 0.3749],
[0.4212, 0.9243],
[0.1965, 0.9654],
[0.7230, 0.6295]]])
idx = torch.tensor([3,0])
N = get_N(M, idx)
print(N)
结果:
tensor([[0.0567, 0.5602],
[0.5450, 0.3749]])
沿二维切片。
推荐阅读
- applescript - AppleScript Photos 文件夹引用在子例程中不起作用 - 为什么?
- asp.net-core - .NET Core Razor 页面下拉菜单不返回值
- javascript - 在两个 MySql 表之间共享 user_id 值
- java - AWSToolkit for Eclipse --- 部署无服务器项目挂起 10%
- tsql - 为什么 ANTLR 语法文件中的类似规则会产生完全不同的树?
- amazon-web-services - AWS 组织账户管理
- cypher - neo4j n-levels的父子关系
- reactjs - 反应:我的方法没有根据其名称过滤我的数据
- qt - 为什么按钮不适合翻译?
- android - Expo Go 应用程序在 Android 上崩溃 | 数字文字后不允许直接使用标识符