python - 线性回归梯度
问题描述
我有非常基本的线性回归样本。下面的实现(没有正则化)
class Learning:
def assume(self, weights, x):
return np.dot(x, np.transpose(weights))
def cost(self, weights, x, y, lam):
predict = self.assume(weights, x) \
.reshape(len(x), 1)
val = np.sum(np.square(predict - y), axis=0)
assert val is not None
assert val.shape == (1,)
return val[0] / 2 * len(x)
def grad(self, weights, x, y, lam):
predict = self.assume(weights, x)\
.reshape(len(x), 1)
val = np.sum(np.multiply(
x, (predict - y)), axis=0)
assert val is not None
assert val.shape == weights.shape
return val / len(x)
我想检查渐变,它是否有效,使用scipy.optimize
.
learn = Learning()
INPUTS = np.array([[1, 2],
[1, 3],
[1, 6]])
OUTPUTS = np.array([[3], [5], [11]])
WEIGHTS = np.array([1, 1])
t_check_grad = scipy.optimize.check_grad(
learn.cost, learn.grad, WEIGHTS,INPUTS, OUTPUTS, 0)
print(t_check_grad)
# Output will be 73.2241602235811!!!
我从头到尾手动检查了所有计算。它实际上是正确的实现。但在输出中我看到了非常大的差异!是什么原因?
解决方案
在您的成本函数中,您应该返回
val[0] / (2 * len(x))
而不是val[0] / 2 * len(x)
. 然后你会有
print(t_check_grad)
# 1.20853633278e-07
推荐阅读
- spring - Spring Boot Security:如何在 Spring Boot 中的 CSRF 之前运行身份验证过滤器?
- java - 为什么我的 Mongo 聚合不能在嵌套文档上正常工作?
- css - 输入html的多个点
- javascript - 如何通过布尔属性和整数属性对对象数组进行排序,如果该整数属性为 0,则保持接近结尾?
- spring - spring data jpa 无法使用原生 sql 更新数据
- sql - 可以将聚合函数与连接一起使用吗?
- python - 从列表中获取所有不包含相同元素的对
- flutter - Flutter Web - 无法编译,包中的双重导入问题
- json - VS Code 中 json 模式的相对文件路径
- apache-spark - 刚刚安装了火花和斯卡拉。返回不支持的类文件主要版本:58