python - 在 Python 中使用线性代数进行线性回归
问题描述
我在 Python 中解释维基百科( https://en.wikipedia.org/wiki/Coefficient_of_determination )上的这些公式是否错误?以下是我尝试过的。
def ss_res(X, y, theta):
y_diff=[]
y_pred = X.dot(theta)
for i in range(0, len(y)):
y_diff.append((y[i]-y_pred[i])**2)
return np.sum(y_diff)
输出看起来正确,但数字略有偏差......就像几个小数点。
def std_error(X, y, theta):
delta = (1/(len(y)-X.shape[1]+1))*(ss_res(X,y,theta))
matrix1=matrix_power((X.T.dot(X)),-1)
thing2=delta*matrix1
thing3=scipy.linalg.sqrtm(thing2)
res=np.diag(thing3)
serr=np.reshape(res, (6, 1))
return serr
std_error_array=std_error(X,y,theta)
解决方案
你可能想要也可能不想要+1
你所说delta
的,取决于你是否X
包含一个“常量”列(即所有值 = 1)
否则,如果有点非 Pythonic,它看起来还不错。我很想把它们写成:
import numpy as np
from numpy.linalg import inv
from scipy.linalg import sqrtm
def solve_theta(X, Y):
return np.linalg.solve(X.T @ X, X.T @ Y)
def ss_res(X, Y, theta):
res = Y - (X @ theta)
return np.sum(res ** 2)
def std_error(X, Y, theta):
nr, rank = X.shape
resid_df = nr - rank
residvar = ss_res(X, Y, theta) / resid_df
var_theta = residvar * inv(X.T @ X)
return np.diag(sqrtm(var_theta))[:,None]
注意:这使用Python 3.5 风格的矩阵乘法运算符 @
而不是写出.dot()
这种算法的数值稳定性并不令人惊讶,您可能想看看使用 SVD 或 QR 分解。有一个平易近人的描述,你将如何使用 SVD 在:
John Mandel (1982)“在回归分析中使用奇异值分解” 10.1080/00031305.1982.10482771
我们可以通过创建一些虚拟数据来测试:
np.random.seed(42)
N = 20
K = 3
true_theta = np.random.randn(K, 1) * 5
X = np.random.randn(N, K)
Y = np.random.randn(N, 1) + X @ true_theta
并在上面运行上面的代码:
theta = solve_theta(X, Y)
sse = std_error(X, Y, theta)
print(np.column_stack((theta, sse)))
这使:
[[ 2.23556391 0.35678574]
[-0.40643163 0.24751913]
[ 3.14687637 0.26461827]]
我们可以使用以下方法进行测试statsmodels
:
import statsmodels.api as sm
sm.OLS(Y, X).fit().summary()
这使:
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
x1 2.2356 0.358 6.243 0.000 1.480 2.991
x2 -0.4064 0.248 -1.641 0.119 -0.929 0.116
x3 3.1469 0.266 11.812 0.000 2.585 3.709
这非常接近。
推荐阅读
- c# - 3d 文本上的 transform.LookAt 正在向后显示文本
- java - java.sql.Date错误如何避免呢?将日期插入我的数据库
- python - 我在反向欧拉中遇到双标量溢出和双标量错误中遇到的无效值
- c++ - 使用另一个四元数旋转四元数旋转游戏对象时出现问题
- php - MySQL 是否(或如何)本机地处理并发事务?
- macos - 如何使用 Homebrew 通过 Vim 8 安装 Python 3 支持
- c - 二进制文件打印最后一个元素两次?
- java - 在 for 循环中使用带有 JavaFX 的时间线关键帧
- javascript - 事件委托无法与 if else 语句一起正常工作
- javascript - jQuery 动画延迟在 Chrome 中不起作用