首页 > 解决方案 > 如何返回基于 1D 掩码截断的 k-dim pytorch 张量

问题描述

因此,假设我有一个 k-dim 张量和一个 1-dim 掩码,这在 pytorch 中用于可变长度序列,并且我想返回一个张量,该张量表示直到掩码中第一个假值的元素。这是一个例子:

import torch

a = torch.tensor([[1,2],[3,4],[5,6],[0,0],[0,0],[0,0]])
b = torch.tensor([True,True,True,False,False,False])

# magic goes here, result of c should be:

print(c)
>>> [[1,2],[3,4],[5,6]]

在此示例中,输入张量是 2D,但它可以是 kd,在这些维度上具有任意数量的值。只有第一个维度需要匹配掩码维度。因此,torch.masked_select 不起作用,因为要截断的张量不像掩码那样是一维的,并且由于您不知道维度,因此挤压和解压也不是解决方案。

对于前 k 个元素,掩码始终为真,而对于其余元素,掩码始终为假,但如果您的解决方案不“依赖”于此,那很好。

这似乎人们会一直这样做,但我找不到任何地方已经回答了这个问题。

标签: pythonpytorchboolean

解决方案


您可以简单地将掩码作为切片索引传递给张量:

c = a[b]
>>> c
tensor([[1, 2],
        [3, 4],
        [5, 6]])

推荐阅读