首页 > 解决方案 > 在 PyTorch 中屏蔽 3D 张量中的前 k 个元素(每行不同的 k)

问题描述

我有一个M维度的张量[NxQxD]和一个索引的一维张量idx(大小N)。我想有效地创建一个mask维度[NxQxD]的张量mask[i,j,k] = 1 iff j <= idx[i],即我只想将idx[i]第一个维度保留Q在第二个维度 (dim=1) 中M,对于每一行i

谢谢!

标签: pythonpytorchtensor

解决方案


事实证明,这可以通过广播技巧来完成:

mask_2d = torch.arange(Q)[None, :] < idx[:, None] #(N,Q)
mask_3d = mask[..., None] #(N,Q,1)
masked = mask.float() * data

推荐阅读