r - 为什么 Adam 优化无法在线性回归中收敛?
问题描述
我正在研究亚当优化器。这是一个玩具问题。在 R 中,我生成了一些人工数据:
Y = c0 + c1 * x1 + c2 * x2 + 噪声
在上面的等式中,x1、x2 和噪声是我在 R 中生成的正常随机数,theta = [c0, c1, c2] 是我尝试用 Adam 优化器估计的参数。对于这个简单的回归问题,我可以使用分析方法来确定 theta 参数,即下面我的 R 代码中的 k。
关于亚当算法,我使用了这个网站的公式
我在这个参数研究中改变了步长 eta。Adam 算法的最终 theta 与我的 R 代码中 的解析解k不同。 我检查了我的代码很多次。我逐行运行代码,无法理解为什么 Adam 算法不能收敛。
补充:我将算法更改为 AMSGrad。在这种情况下,它的表现比 Adam 好。但是,AMSGrad 不收敛。
rm(list = ls())
n=500
x1=rnorm(n,mean=6,sd=1.6)
x2=rnorm(n,mean=4,sd=2.5)
X=cbind(x1,x2)
A=as.matrix(cbind(intercept=rep(1,n),x1,x2))
Y=-20+51*x1-15*x2+rnorm(n,mean=0,sd=2);
k=solve(t(A)%*%A,t(A)%*%Y) # k is the parameters determined by analytical method
MSE=sum((A%*%k-Y)^2)/(n);
iterations=4000 # total number of steps
epsilon = 0.0001 # set precision
eta=0.04 # step size
beta1=0.9
beta2=0.999
t1=integer(iterations)
t2=matrix(0,iterations,3)
t3=integer(iterations)
epsilon1=1E-8 # small number defined for numerical computation
X=as.matrix(X)# convert data table X into a matrix
N=dim(X)[1] # total number of observations
X=as.matrix(cbind(intercept=rep(1,length(N)),X))# add a column of ones to represent intercept
np=dim(X)[2] # number of parameters to be determined
theta=matrix(rnorm(n=np,mean=0,sd=2),1,np) # Initialize theta:1 x np matrix
m_i=matrix(0,1,np) # initialization, zero vector
v_i=matrix(0,1,np) # initialization, zero vector
for(i in 1:iterations){
error=theta%*%t(X)-t(Y) # error = (theta * x' -Y'). Error is a 1xN row vector;
grad=1/N*error%*%X # Gradient grad is 1 x np vector
m_i=beta1*m_i+(1-beta1)*grad # moving average of gradient, 1 x np vector
v_i=beta2*v_i+(1-beta2)*grad^2 # moving average of squared gradients, 1 x np vector
# corrected moving averages
m_corrected=m_i/(1-beta1^i)
v_corrected=v_i/(1-beta2^i)
d_theta=eta/(sqrt(v_corrected)+epsilon1)*m_corrected
theta=theta-d_theta
L=sqrt(sum((d_theta)^2)) # calculating the L2 norm
t1[i]=L # record the L2 norm in each step
if ((is.infinite(L))||(is.nan(L))) {
print("Learning rate is too large. Lowering the rate may help.")
break
}
else if (L<=epsilon) {
print("Algorithm convergence is reached.")
break # checking whether convergence is obtained or not
}
# if (i==1){
# browser()
# }
}
plot(t1,type="l",ylab="norm",lwd=3,col=rgb(0,0,1))
k
theta
解决方案
推荐阅读
- r - R - 如何获得 2 个阶跃函数的差值/总和?
- elasticsearch - Cloudwatch 到 Elasticsearch 在推送到 ES 之前解析/标记日志事件
- java - 使用 Java 解析器在 XML 属性中保留 /t 和 /n
- c - 在C中的字符串中返回置换
- sql - 多对多?
- javascript - 在html页面中编写javascript出现在页面中
- android - 如何在 Android Test 接口中对静态方法进行单元测试?
- rollup - 使用汇总 js 将多个 es6 类组合成单个库
- python - Tensorflow 图像分割权重未更新
- sql - 查找包含特定表名的所有视图