首页 > 解决方案 > 张量操作 - 从给定张量创建位置张量

问题描述

我有一个输入张量,它在开始时填充为零,然后是一系列值。所以像:

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])

我需要的是另一个具有相同形状并为填充保留 0 并为其余数字保留递增序列的张量。所以我的输出张量应该是:

y = ([[0, 1, 2, 3],
      [0, 0, 1, 2]])

我试过类似的东西:

MAX_SEQ=4
seq_start = np.nonzero(x)
start = seq_start[0][0]
pos_id = torch.cat((torch.from_numpy(np.zeros(start, dtype=int)).to(device), torch.arange(1, MAX_SEQ-start+1).to(device)), 0)
print(pos_id)

如果张量是 1 维的,但需要额外的逻辑来处理它以用于 2-D 形状,则此方法有效。这可以通过 np.nonzeros 返回一个元组来完成,我们可能会循环通过那些更新计数器或其他东西的元组。但是我确信必须有一个简单的张量操作,它应该在 1-2 行代码中完成,而且可能更有效。

帮助表示赞赏

标签: pythonpytorchtensortorch

解决方案


三个小步骤的可能解决方案:

  1. 查找每行的第一个非零元素的索引。这可以通过此处解释的技巧来完成此处适用于非二进制张量)。

    > idx = torch.arange(x.shape[1], 0, -1)
    tensor([4, 3, 2, 1])
    
    > xbin = torch.where(x == 0, 0, 1)
    tensor([[0, 1, 1, 1],
            [0, 0, 1, 1]])
    
    > xbin*idx
    tensor([[0, 3, 2, 1],
            [0, 0, 2, 1]])
    
    > indices = torch.argmax(xbin*idx, dim=1, keepdim=True)
    tensor([[1],
            [2]])
    
  2. 为结果张量创建一个排列(没有填充)。这可以通过应用torch.repeattorch.view来完成torch.arange call

    > rows, cols = x.shape
    > seq = torch.arange(1, cols+1).repeat(1, rows).view(-1, cols)
    tensor([[1, 2, 3, 4],
            [1, 2, 3, 4]])
    
  3. 最后 - 这是诀窍!- 对于每一行,我们用排列减去第一个非零元素的索引。然后我们屏蔽填充值并用零替换它们:

    > pos_id = seq - indices
    tensor([[ 0,  1,  2,  3],
            [-1,  0,  1,  2]])
    
    > mask = indices > seq - 1
    tensor([[ True, False, False, False],
            [ True,  True, False, False]])
    
    > pos_id[mask] = 0
    tensor([[0, 1, 2, 3],
            [0, 0, 1, 2]])
    

推荐阅读