首页 > 解决方案 > 为什么 PyTorch 和 NumPy 计算出的矩阵逆结果不同？

问题描述

这是数据:

https://github.com/GuokaiLiu/MyIssues/blob/main/dataMap.mat

这是代码:

import torch 
import numpy as np
import scipy.io as sio

def pinv1(A,reg):
    """Compute the regularised pseudo-inverse (reg*I + A^T A)^-1 A^T with NumPy.

    Args:
        A: 2-D array of shape (m, n).
        reg: scalar (or array broadcastable against an (n, n) matrix) that is
            added to the diagonal of A^T A.

    Returns:
        (pAdd, pInv, pDot): the regularised Gram matrix ``reg*I + A^T A``
        (ndarray), its inverse (``np.matrix``) and ``pInv @ A^T`` (``np.matrix``).
    """
    pAdd = reg * np.eye(A.shape[1]) + A.T.dot(A)
    # np.mat was removed in NumPy 2.0; np.asmatrix is the same function and
    # still returns an np.matrix, so ``.I`` and the return types are unchanged.
    pInv = np.asmatrix(pAdd).I
    pDot = pInv.dot(A.T)
    return pAdd, pInv, pDot

def pinv2(A,reg):
    """Torch counterpart of ``pinv1``: invert reg*I + A^T A and multiply by A^T.

    Both the inverse and ``A.T`` are cast to float32 before the final product,
    so the returned inverse and product are always float32 tensors.

    Returns:
        (pAdd, pInv, pDot): ``reg*I + A^T A``, its inverse, and ``pInv @ A^T``.
    """
    gram = A.T @ A
    identity = torch.eye(A.shape[1])
    pAdd = reg * identity + gram
    pInv = pAdd.inverse().to(torch.float32)
    at_f32 = A.T.to(torch.float32)
    return pAdd, pInv, pInv @ at_f32

# Load the matrix once with SciPy; the torch copy below is built from it.
npData = sio.loadmat('dataMap.mat')['dataMap']
# Build the tensor from the array just loaded (``dataMap`` is not defined in
# this script). torch.Tensor produces float32 under torch's default dtype.
ptData = torch.Tensor(npData)

a1,a2,a3 = pinv1(npData,2**-30)
b1,b2,b3 = pinv2(ptData,torch.Tensor([2**-30]))

# L2 distance between the NumPy and torch results for the regularised Gram
# matrix, its inverse and the final product.
print(torch.dist(torch.Tensor(np.array(a1)),b1,2))
print(torch.dist(torch.Tensor(np.array(a2)),b2,2))
print(torch.dist(torch.Tensor(np.array(a3)),b3,2))

结果如下:

tensor(0.0022)
tensor(9947374.)
tensor(99862.0469)

为什么 PyTorch 和 NumPy 计算出的矩阵逆结果不同？

标签: numpy, matrix, statistics, pytorch, inverse

解决方案


将所有变量都保持为 float64（双精度）。

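看来矩阵求逆对 float32/float64 精度非常敏感。正则项 2**-30（约 9.3e-10）远小于 float32 的相对精度（约 1.2e-7）。因此，当 A^T A 的对角元量级不太小时，在 float32 下把它加到对角线上几乎会被完全舍入掉，A^T A + λI 可能接近奇异，求逆时舍入误差会被大幅放大。改用 float64 后，与 NumPy 的差异明显减小（见下方结果）。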

import torch 
import numpy as np
import scipy.io as sio
# Make float64 the default floating-point dtype so that tensors created
# without an explicit dtype (e.g. torch.eye in pinv2) are double precision,
# matching NumPy's float64 default.
torch.set_default_dtype(torch.float64)

def pinv1(A,reg):
    """NumPy reference for the regularised pseudo-inverse (reg*I + A^T A)^-1 A^T.

    Args:
        A: 2-D CPU torch tensor or array-like of shape (m, n); it is converted
            to a float64 NumPy array before any arithmetic.
        reg: scalar (or array broadcastable against an (n, n) matrix) that is
            added to the diagonal of A^T A.

    Returns:
        (pAdd11, pAdd12, pInv, pDot): ``reg*I`` and ``A^T A`` (ndarrays), the
        inverse of their sum and ``pInv @ A^T`` (both ``np.matrix``).
    """
    # Accept torch tensors (the original interface) as well as NumPy arrays.
    if isinstance(A, torch.Tensor):
        A = A.numpy()
    A = np.asarray(A).astype('float64')
    pAdd11 = reg * np.eye(A.shape[1])
    pAdd12 = A.T.dot(A)
    pAdd = pAdd11 + pAdd12
    # np.mat was removed in NumPy 2.0; np.asmatrix returns the same np.matrix.
    pInv = np.asmatrix(pAdd).I
    pDot = pInv.dot(A.T)
    return pAdd11, pAdd12, pInv, pDot

def pinv2(A,reg):
    """Torch version of the regularised pseudo-inverse (reg*I + A^T A)^-1 A^T.

    ``A`` is cast to float64 first. The dtype of ``reg*I`` follows ``reg`` and
    torch's default floating dtype (used by ``torch.eye``).

    Returns:
        (pAdd21, pAdd22, pInv, pDot): ``reg*I``, ``A^T A``, the inverse of
        their sum, and ``pInv @ A^T``.
    """
    a64 = A.double()
    ridge = reg * torch.eye(a64.shape[1])
    gram = a64.T @ a64
    inverse = torch.inverse(ridge + gram)
    return ridge, gram, inverse, inverse @ a64.T

# Load the matrix and keep it as a tensor: pinv1 calls ``A.numpy()`` and
# pinv2 calls ``A.double()``, so both expect a torch tensor.
dataMap = torch.tensor(sio.loadmat('dataMap.mat')['dataMap'])

c = 2**-30
c1 = np.array([c])
# Build the torch regulariser from the same constant, explicitly in float64.
c2 = torch.tensor([c], dtype=torch.float64)
a1,a2,a3,a4 = pinv1(dataMap,c1)
b1,b2,b3,b4 = pinv2(dataMap,c2)

# L2 distances between the NumPy and torch results for reg*I, A^T A, the
# inverse and the final product. The remaining differences presumably come
# from the different LAPACK/BLAS routines each library calls; with such a
# small regulariser the system may be badly conditioned, which amplifies them.
print(torch.dist(torch.from_numpy(a1),b1,2))
print(torch.dist(torch.from_numpy(a2),b2,2))
print(torch.dist(torch.from_numpy(a3),b3,2))
print(torch.dist(torch.from_numpy(a4),b4,2))

结果:

tensor(0.)
tensor(4.7235e-12)
tensor(0.0233) # How to improve this?
tensor(6.8622e-05)

推荐阅读