numpy - 为什么 PyTorch 和 NumPy 之间的逆结果不同
问题描述
这是数据:
https://github.com/GuokaiLiu/MyIssues/blob/main/dataMap.mat
这是代码:
import torch
import numpy as np
import scipy.io as sio
def pinv1(A,reg):
pAdd = reg*np.eye(A.shape[1])+A.T.dot(A)
pInv = np.mat(pAdd).I
pDot = pInv.dot(A.T)
return pAdd, pInv, pDot
def pinv2(A,reg):
pAdd = reg*torch.eye(A.shape[1])+torch.matmul(A.T,A)
pInv = torch.inverse(pAdd).float()
pDot = torch.matmul(pInv,A.T.float())
return pAdd, pInv, pDot
npData = sio.loadmat('dataMap.mat')['dataMap']
ptData = torch.Tensor(dataMap)
a1,a2,a3 = pinv1(npData,2**-30)
b1,b2,b3 = pinv2(ptData,torch.Tensor([2**-30]))
print(torch.dist(torch.Tensor(np.array(a1)),b1,2))
print(torch.dist(torch.Tensor(np.array(a2)),b2,2))
print(torch.dist(torch.Tensor(np.array(a3)),b3,2))
结果如下:
tensor(0.0022)
tensor(9947374.)
tensor(99862.0469)
为什么 PyTorch 和 NumPy 之间的逆结果不同
解决方案
使用 float64 保留所有变量。
看来逆运算对float64/float32很敏感
import torch
import numpy as np
import scipy.io as sio
torch.set_default_dtype(torch.float64)
def pinv1(A,reg):
A = A.numpy().astype('float64')
pAdd11 = reg*np.eye(A.shape[1])
pAdd12 = A.T.dot(A)
pAdd = pAdd11 + pAdd12
pInv = np.mat(pAdd).I
pDot = pInv.dot(A.T)
return pAdd11,pAdd12, pInv, pDot
def pinv2(A,reg):
A = A.double()
pAdd21 = reg*torch.eye(A.shape[1])
pAdd22 = torch.matmul(A.T,A)
pAdd = pAdd21 + pAdd22
pInv = torch.inverse(pAdd)
pDot = torch.matmul(pInv,A.T)
return pAdd21,pAdd22, pInv, pDot
# npData = sio.loadmat('dataMap.mat')['dataMap']
# ptData = torch.Tensor(dataMap)
c = 2**-30
# c = np.array(np.float64(c))
np.array(c)
c1 = np.array([c])
c2 = torch.Tensor([2**-30])
a1,a2,a3,a4 = pinv1(dataMap,c1)
b1,b2,b3,b4 = pinv2(dataMap,c2)
# a1,a2,a3 = pinv2(npData,2**-30)
print(torch.dist(torch.from_numpy(a1),b1,2))
print(torch.dist(torch.from_numpy(a2),b2,2))
print(torch.dist(torch.from_numpy(a3),b3,2))
print(torch.dist(torch.from_numpy(a4),b4,2))
结果:
tensor(0.)
tensor(4.7235e-12)
tensor(0.0233) # How to improve this?
tensor(6.8622e-05)
推荐阅读
- javascript - 拆分字符串并创建以下数组的最佳方法:'1-2-3' -> ['1', '1-2', '1-2-3']
- firebase - Firebase 可调用函数:如何限制 CORS 来源
- in-app-purchase - IAP 提交审核按钮变灰 Xcode 14
- angular-ui-router - 当库具有路由延迟加载语法并尝试构建它时,Angular CLI 会抛出错误。对此的任何帮助表示赞赏
- regex - 如何使用 sed 删除括号但不是全部
- mapbox-gl-js - Mapbox Sheet Mapper 插件是否仅将 Google Sheet 列数据读取为字符串?
- python - 尝试将 pip 从版本 20.3.1 升级到版本 21.3。如何使用 `--user` 选项?
- python - 加快 scipy.stats.hypergeom 计算
- apache-kafka - 将消息从两个 Kafka 集群转发到另一个
- python - 如何在 Tkinter 上/中/与 Tkinter 一起使用 Canvas?