首页 > 解决方案 > 如何更改某些行和列中二维张量的值

问题描述

假设我有一个像这样的全零掩码张量:

mask = torch.zeros(5,3, dtype=torch.bool)

现在我想将以下和索引mask的交点处的值设置为:rowscolsTrue

rows = torch.tensor([0,2,4]) 
cols = torch.tensor([1,2])

我想产生以下结果:

tensor([[False, True,  True ],
        [False, False, False],
        [False, True,  True ],
        [False, False, False],
        [False, True,  True ]])

当我尝试以下代码时,我收到一个错误:

mask[rows, cols] = True

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]

我怎样才能在 PyTorch 中有效地做到这一点?

标签: pythonpytorch

解决方案


您需要适当的形状才能使用torch.unsqueeze

mask = torch.zeros(5,3, dtype=torch.bool)
mask[rows, cols.unsqueeze(1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])

或者torch.reshape

mask[rows, cols.reshape(-1,1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])

推荐阅读