首页 > 解决方案 > 如何沿对角线替换 PyTorch 张量中的特定值?

问题描述

例如,有一个 PyTorch 矩阵A

A = tensor([[3,2,1],[1,0,2],[2,2,0]])

我需要在对角线上用 1 替换 0,所以结果应该是:

tensor([[3,2,1],[1,1,2],[2,2,1]])

标签: pythonpytorchdiagonal

解决方案


您可以使用 torch 的内置对角函数来替换对角元素,如下所示:

mask = A.diagonal() == 0
A += torch.diag(mask)
>>> A
tensor([[3, 2, 1],
        [1, 1, 2],
        [2, 2, 1]])

如果要将 0 替换为另一个值,请更改maskmask * replace_value.


推荐阅读