首页 > 解决方案 > 如何在索引处添加到 pytorch 张量?

问题描述

我不得不承认,我对 scatter* 和 index* 操作有点困惑——我不确定它们中的任何一个都完全符合我的要求,这很简单:

给定一些二维张量

z = tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])

以及二维索引的列表(或张量?):

inds = tensor([[0, 0],
               [1, 1],
               [1, 2]])

我想在这些索引处向 z 添加一个标量(并有效地做到这一点):

znew = z.something_add(inds, 3)
->
znew = tensor([[4., 1., 1., 1.],
               [1., 4., 4., 1.],
               [1., 1., 1., 1.]])

如果必须,我可以使该标量成为任何形状的张量(所有元素 = 3),但我宁愿不...

标签: pytorch

解决方案


您必须为索引提供两个列表。第一个具有行位置,第二个具有列位置。在您的示例中,它将是:

z[[0, 1, 1], [0, 1, 2]] += 3

torch.Tensor 索引遵循 Numpy。有关更多详细信息,请参阅https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing


推荐阅读