首页 > 解决方案 > torch.gather(...) 调用的设置结果

问题描述

我有一个形状为 n x m 的 2D pytorch 张量。我想使用索引列表(可以使用 torch.gather 完成)对第二维进行索引,然后还为索引结果设置新值

例子:

data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

我想为每行选择指定的索引(这将是[1,5,7]但随后也将这些值设置为另一个数字 - 例如 42

我可以通过执行以下操作逐行选择所需的列:

data.gather(1, indices)
tensor([[1],
        [5],
        [7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather 
                                # does not use the same storage as the original tensor

这很好,但我现在想更改这些值,并且更改也会影响data张量。

我可以用它来做我想做的事情,但它似乎非常不符合pythonic:

max_index = torch.max(indices)
for i in range(0, max_index + 1):
  mask = (indices == i).nonzero(as_tuple=True)[0]
  data[mask, i] = 42
print(data)
# tensor([[ 0, 42,  2],
#         [ 3,  4, 42],
#         [ 6, 42,  8]])

关于如何更优雅地做到这一点的任何提示?

标签: pythonindexingpytorchtensor

解决方案


您正在寻找的是torch.scatter_选项value

Tensor.scatter_(dim, index, src, reduce=None) → Tensor
将张量中的所有值写入张量srcself指定的索引处index。对于 中的每个值src,其输出index由 src fordimension != dim中的索引和 index for 中的相应值指定dimension = dim

以 2D 张量作为输入 和dim=1,运算为:
self[i][index[i][j]] = src[i][j]

虽然没有提到 value 参数......


使用value=42、 和dim=1,这将对数据产生以下影响:

data[i][index[i][j]] = 42

这里就地应用:

>>> data.scatter_(index=indices, dim=1, value=42)
>>> data
tensor([[ 0, 42,  2],
        [ 3,  4, 42],
        [ 6, 42,  8]])

推荐阅读