python - 对于大型矩阵,如何使用 scipy.linalg.lstsq 避免内存错误?
问题描述
我需要计算许多不同数据集的 Mahalobis 距离,并且我遇到了其中最大的 scipy.linalg.lstsq 函数的内存错误。是否有任何可能的方法来使用较慢的函数或分块这部分分析?
我有 8 个不同的 CPU 和 32 GB 的 RAM。有时我可以看到 RAM 使用率在内存错误(第 1 个打印行)之前达到 100%,但有时只有 1 或 2 个 CPU 达到 100%,并且在我收到内存错误(第 2 个打印行)之前 RAM 仅达到 40% .
当我收到错误时,打印语句给了我:
1123840 71836800 (8780, 16) (561225, 16)
1169920 103577856 (9140, 16) (809202, 16)
这些形状与 mahal Y 和 X 的输入的形状完全相同。两个数组都是 float64 并且我尝试减少到 float 32 但这给了我不同的结果。我也在输入上尝试了 np.around() ,但这并没有改变字节大小。
def mahal(Y, X):
""" Function translated directly from MATLAB. Tested to give equivalent outputs.
"""
[rx,cx] = X.shape
[ry,cy] = Y.shape
m = numpy.mean(X,axis=0);
M = m * numpy.ones([ry,1])
C = X - (m * numpy.ones([rx,1]))
Q,R = scipy.linalg.qr(C) #[Q,R] = qr(C,0)
print R.nbytes, (Y-M).nbytes, R.shape, (Y-M).shape
ri = linalg.lstsq(R.T,(Y-M).T)[0][0:cx] # For some reason, there are a lot of extra zeros that need to be removed
d = numpy.sum((ri * ri),axis=0).T*(rx-1)
return d
解决方案
使用这样的最小二乘法将导致创建一个巨大的中间矩阵。我没有 Matlab,但我可以在GNU Octave中找到等价物。将其翻译成 Python 给出:
import numpy as np
def mahal(Y, X):
xr, xc = X.shape
yr, yc = Y.shape
assert xc == yc, "X and Y must have the same number of columns"
mu = np.mean(X, 0)
x = X - mu
y = Y - mu
w = x.T @ x
w /= xr - 1
w = y @ np.linalg.pinv(np.linalg.cholesky(w).T)
return np.einsum('ij,ij->i', w, w)
GNU Octave 代码还包括一些测试,我们可以检查:
X = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
assert np.allclose(mahal(X, X), [1.5, 1.5, 1.5, 1.5])
assert np.allclose(mahal(X, X+1), [7.5, 7.5, 1.5, 13.5])
这对我来说成功了。接下来我们可以测试使用大矩阵,就像你做的那样,在合理的时间内给出答案:
X = np.random.rand(561225, 16)
Y = np.random.rand(809202, 16)
mahal(Y, X)
在 180 毫秒内运行mahal
函数
我正在使用 Python 3.5@
语法进行矩阵乘法,如果您使用的是 Python 的古老版本,则需要更改它。
推荐阅读
- javascript - 计算剩余天数百分比
- reactjs - 当状态为空对象时反应不重新渲染
- oracle - 应用程序无法通过 VPN 连接到 Oracle 数据库
- ruby-on-rails - 如何在 VS Code 中调试 jruby 应用程序
- python - 优化找到一组重叠段的最大最小整数的函数
- ios - 有没有办法在空的 uitextfield 中跟踪字符删除?
- html - flex-wrap 在 Internet Explorer 中不起作用
- python - 如何在不暂停代码的情况下延迟特定功能?
- database - 数据库触发器 - 更好的选择是什么
- php - 作曲家更新后,ActionColumn 内的图标被缩放