python - 梯度下降 - 作为列表和作为 numpy 数组的 theta 之间的区别
问题描述
我已经实现了梯度下降算法,并且根据我的 theta 是列表类型还是 numpy 数组产生不同的结果:当 theta 是 python 列表时,我的程序运行良好,但使用 theta = np.zeros((2, 1 )) 出了点问题,我的 theta 增长得非常快。
num_iter = 1500
alpha = 0.01
theta = [0, 0]
#theta = np.zeros((2, 1), dtype=np.float64)
print(theta)
def gradient_descent(x, y, theta, alpha, iteration):
m = y.size
i = 0
temp = np.zeros_like(theta, np.float64)
for i in range(iteration):
h = x @ theta
temp[0] = (alpha/m)*(np.sum(h - y))
temp[1] = (alpha/m)*(np.sum((h - y)*x[:,1]))
theta[0] -= temp[0]
theta[1] -= temp[1]
print("theta0 {}, theta1 {}, Cost {}".format(theta[0], theta[1], compute_cost(x, y, theta)))
return theta, J_history
theta = gradient_descent(X, y, theta, alpha, num_iter)
回答 theta 作为 numpy 数组
theta0 [5.663961], theta1 [63.36898425], Cost 15846739.108595487
theta0 [-495.73201075], theta1 [-4010.76967073], Cost 65114528414.94523
theta0 [31736.05800912], theta1 [259011.3427287], Cost 271418872442062.44
.
.
.
theta0 [nan], theta1 [nan], Cost nan
theta0 [nan], theta1 [nan], Cost nan
theta0 [nan], theta1 [nan], Cost nan
当 theta 是一个列表时回答
theta0 0.05839135051546392, theta1 0.6532884974555672, Cost 6.737190464870008
theta0 0.06289175271039384, theta1 0.7700097825599365, Cost 5.9315935686049555
.
.
.
theta0 -3.6298120050247746, theta1 1.166314185951815, Cost 4.483411453374869
theta0 -3.6302914394043593, theta1 1.166362350335582, Cost 4.483388256587725
解决方案
您的两个 theta 具有不同的形状:theta = [0,0]
具有形状 (1,2),但theta = np.zeros((2,1))
具有形状 (2,1)。因此,如果x
形状为 (n,),则x @ theta
第一个为 (1,n),第二个为 (n,1)。
例如,
t1 = [0,0]
t2 = np.zeros((2,1))
t3 = np.zeros((2,))
x = np.arange(6).reshape(3,2)
x @ t1
# array([0, 0, 0])
x @ t2
# array([[0.],
# [0.],
# [0.]])
x @ t3
# array([0, 0, 0])
更改theta = np.zeros((2,))
为(我认为)是一个快速解决方案。
推荐阅读
- reactjs - 创建一个不会在启动时获取后端 API 的资源
- scala - 范围内的隐式参数不应该用作宏生成代码的隐式参数吗?
- javascript - 等待 Promise Value 的行为
- javascript - 使用 Javascript 读取 Linux 系统文件?
- salesforce - 我能否在 Salesforce 中不让用户知道的情况下阻止将记录插入对象触发器中?
- javascript - 存储画布上下文并稍后编辑属性
- php - ConvertAPI 结果数组而不是 ResultFiles
- javascript - 碰撞检测的问题是 JavaScript
- excel - 需要匹配第一列中的条件,然后匹配第二列中的条件以显示该行信息
- field - Hapi HL7如何计算一个段或字段重复的次数