首页 > 解决方案 > 对于大型矩阵,如何使用 scipy.linalg.lstsq 避免内存错误?

问题描述

我需要计算许多不同数据集的 Mahalobis 距离,并且我遇到了其中最大的 scipy.linalg.lstsq 函数的内存错误。是否有任何可能的方法来使用较慢的函数或分块这部分分析?

我有 8 个不同的 CPU 和 32 GB 的 RAM。有时我可以看到 RAM 使用率在内存错误(第 1 个打印行)之前达到 100%,但有时只有 1 或 2 个 CPU 达到 100%,并且在我收到内存错误(第 2 个打印行)之前 RAM 仅达到 40% .

当我收到错误时,打印语句给了我:

1123840 71836800 (8780, 16) (561225, 16)

1169920 103577856 (9140, 16) (809202, 16)

这些形状与 mahal Y 和 X 的输入的形状完全相同。两个数组都是 float64 并且我尝试减少到​​ float 32 但这给了我不同的结果。我也在输入上尝试了 np.around() ,但这并没有改变字节大小。

def mahal(Y, X):
    """ Function translated directly from MATLAB.  Tested to give equivalent outputs.
    """
    [rx,cx] = X.shape
    [ry,cy] = Y.shape
    m = numpy.mean(X,axis=0);
    M = m * numpy.ones([ry,1]) 
    C = X - (m * numpy.ones([rx,1]))
    Q,R = scipy.linalg.qr(C) #[Q,R] = qr(C,0)
    print R.nbytes, (Y-M).nbytes, R.shape, (Y-M).shape
    ri = linalg.lstsq(R.T,(Y-M).T)[0][0:cx] # For some reason, there are a lot of extra zeros that need to be removed
    d = numpy.sum((ri * ri),axis=0).T*(rx-1)
    return d

标签: pythonscipy

解决方案


使用这样的最小二乘法将导致创建一个巨大的中间矩阵。我没有 Matlab,但我可以在GNU Octave中找到等价物。将其翻译成 Python 给出:

import numpy as np

def mahal(Y, X):
    xr, xc = X.shape
    yr, yc = Y.shape

    assert xc == yc, "X and Y must have the same number of columns"

    mu = np.mean(X, 0)
    x = X - mu
    y = Y - mu

    w = x.T @ x
    w /= xr - 1

    w = y @ np.linalg.pinv(np.linalg.cholesky(w).T)

    return np.einsum('ij,ij->i', w, w)

GNU Octave 代码还包括一些测试,我们可以检查:

X = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])

assert np.allclose(mahal(X, X), [1.5, 1.5, 1.5, 1.5])
assert np.allclose(mahal(X, X+1), [7.5, 7.5, 1.5, 13.5])

这对我来说成功了。接下来我们可以测试使用大矩阵,就像你做的那样,在合理的时间内给出答案:

X = np.random.rand(561225, 16)
Y = np.random.rand(809202, 16)

mahal(Y, X)

在 180 毫秒内运行mahal函数

我正在使用 Python 3.5@语法进行矩阵乘法,如果您使用的是 Python 的古老版本,则需要更改它。


推荐阅读