首页 > 解决方案 > 如何在 pytorch 中处理运行时错误时矢量化矩阵求逆

问题描述

我需要在 pytorch 中反转一些矩阵。但是,有些矩阵是不可逆的,这导致代码抛出运行时错误如下,

matrices = torch.randn([5,3,3])
matrices[[2,3]] = torch.zeros([3,3])
inverses = torch.inverse(matrices)

RuntimeError: inverse_cpu: For batch 2: U(1,1) is zero, singular U.

对于这种情况,我有一种备用技术。但是,我无法弄清楚哪些矩阵会引发错误。目前,我已将代码替换为非矢量化版本,但它已成为瓶颈。

有没有办法在不放弃矢量化的情况下处理这个问题?

标签: pytorchvectorizationmatrix-inverse

解决方案


我能想到的最好方法是首先计算每个矩阵的行列式,然后计算具有abs(det)>0.

matrices = torch.randn([5,3,3])
matrices[[2,3]] = torch.zeros([3,3])
determinants = torch.det(matrices)
inverses = torch.inverse(matrices[determinants.abs()>0.])

您必须处理奇异矩阵的删除,但这应该不会太难,因为您有这些矩阵的索引值来自determinants.abs()==0.. 这允许您保持反演矢量化。


推荐阅读