首页 > 解决方案 > 在 Python 中使用线性代数进行线性回归

问题描述

我在 Python 中解释维基百科( https://en.wikipedia.org/wiki/Coefficient_of_determination )上的这些公式是否错误?以下是我尝试过的。

ssres

def ss_res(X, y, theta):

    y_diff=[]
    y_pred = X.dot(theta)

    for i in range(0, len(y)):
        y_diff.append((y[i]-y_pred[i])**2)

    return np.sum(y_diff)

输出看起来正确,但数字略有偏差......就像几个小数点。

标准错误

def std_error(X, y, theta):


    delta = (1/(len(y)-X.shape[1]+1))*(ss_res(X,y,theta))
    matrix1=matrix_power((X.T.dot(X)),-1)
    thing2=delta*matrix1
    thing3=scipy.linalg.sqrtm(thing2)

    res=np.diag(thing3)
    serr=np.reshape(res, (6, 1))
    return serr

std_error_array=std_error(X,y,theta)

标签: pythonlinear-regressionlinear-algebra

解决方案


你可能想要也可能不想要+1你所说delta的,取决于你是否X包含一个“常量”列(即所有值 = 1)

否则,如果有点非 Pythonic,它看起来还不错。我很想把它们写成:

import numpy as np
from numpy.linalg import inv
from scipy.linalg import sqrtm

def solve_theta(X, Y):
    return np.linalg.solve(X.T @ X, X.T @ Y)

def ss_res(X, Y, theta):
    res = Y - (X @ theta)
    return np.sum(res ** 2)

def std_error(X, Y, theta):
    nr, rank = X.shape
    resid_df = nr - rank
    residvar = ss_res(X, Y, theta) / resid_df
    var_theta = residvar * inv(X.T @ X)
    return np.diag(sqrtm(var_theta))[:,None]

注意:这使用Python 3.5 风格的矩阵乘法运算符 @而不是写出.dot()

这种算法的数值稳定性并不令人惊讶,您可能想看看使用 SVD 或 QR 分解。有一个平易近人的描述,你将如何使用 SVD 在:

John Mandel (1982)“在回归分析中使用奇异值分解” 10.1080/00031305.1982.10482771

我们可以通过创建一些虚拟数据来测试:

np.random.seed(42)

N = 20
K = 3

true_theta = np.random.randn(K, 1) * 5
X = np.random.randn(N, K)
Y = np.random.randn(N, 1) + X @ true_theta

并在上面运行上面的代码:

theta = solve_theta(X, Y)
sse = std_error(X, Y, theta)

print(np.column_stack((theta, sse)))

这使:

[[ 2.23556391  0.35678574]
 [-0.40643163  0.24751913]
 [ 3.14687637  0.26461827]]

我们可以使用以下方法进行测试statsmodels

import statsmodels.api as sm

sm.OLS(Y, X).fit().summary()

这使:

                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
x1             2.2356      0.358      6.243      0.000       1.480       2.991
x2            -0.4064      0.248     -1.641      0.119      -0.929       0.116
x3             3.1469      0.266     11.812      0.000       2.585       3.709

这非常接近。


推荐阅读