首页 > 解决方案 > 为什么 Adam 优化无法在线性回归中收敛?

问题描述

我正在研究亚当优化器。这是一个玩具问题。在 R 中,我生成了一些人工数据:

Y = c0 + c1 * x1 + c2 * x2 + 噪声

在上面的等式中,x1、x2 和噪声是我在 R 中生成的正常随机数,theta = [c0, c1, c2] 是我尝试用 Adam 优化器估计的参数。对于这个简单的回归问题,我可以使用分析方法来确定 theta 参数,即下面我的 R 代码中的 k。

关于亚当算法,我使用了这个网站的公式

概述:亚当

我在这个参数研究中改变了步长 eta。Adam 算法的最终 theta 与我的 R 代码中 的解析解k不同。规范 我检查了我的代码很多次。我逐行运行代码,无法理解为什么 Adam 算法不能收敛。

补充:我将算法更改为 AMSGrad。在这种情况下,它的表现比 Adam 好。但是,AMSGrad 不收敛。 AMSG研究生院

rm(list = ls())  

n=500
x1=rnorm(n,mean=6,sd=1.6)
x2=rnorm(n,mean=4,sd=2.5)

X=cbind(x1,x2)
A=as.matrix(cbind(intercept=rep(1,n),x1,x2))
Y=-20+51*x1-15*x2+rnorm(n,mean=0,sd=2);


k=solve(t(A)%*%A,t(A)%*%Y) # k is the parameters determined by analytical method
MSE=sum((A%*%k-Y)^2)/(n);

iterations=4000 # total number of steps
epsilon = 0.0001 # set precision
eta=0.04 # step size
beta1=0.9
beta2=0.999

t1=integer(iterations)
t2=matrix(0,iterations,3)
t3=integer(iterations)

epsilon1=1E-8 # small number defined for numerical computation
X=as.matrix(X)# convert data table X into a matrix
N=dim(X)[1] # total number of observations
X=as.matrix(cbind(intercept=rep(1,length(N)),X))# add a column of ones to represent intercept
np=dim(X)[2] # number of parameters to be determined
theta=matrix(rnorm(n=np,mean=0,sd=2),1,np) # Initialize theta:1 x np matrix
m_i=matrix(0,1,np) # initialization, zero vector
v_i=matrix(0,1,np) # initialization, zero vector

for(i in 1:iterations){
  error=theta%*%t(X)-t(Y) # error = (theta * x' -Y'). Error is a 1xN row vector;
  grad=1/N*error%*%X # Gradient grad is 1 x np vector
  m_i=beta1*m_i+(1-beta1)*grad # moving average of gradient, 1 x np vector
  v_i=beta2*v_i+(1-beta2)*grad^2 # moving average of squared gradients, 1 x np vector
  # corrected moving averages
  m_corrected=m_i/(1-beta1^i)
  v_corrected=v_i/(1-beta2^i)
  d_theta=eta/(sqrt(v_corrected)+epsilon1)*m_corrected
  theta=theta-d_theta
  L=sqrt(sum((d_theta)^2)) # calculating the L2 norm
  t1[i]=L # record the L2 norm in each step
  if ((is.infinite(L))||(is.nan(L))) {
    print("Learning rate is too large. Lowering the rate may help.")
    break
  }
  else if (L<=epsilon) {
    print("Algorithm convergence is reached.")
    break # checking whether convergence is obtained or not
  }
  # if (i==1){
  #   browser()
  # }
}

plot(t1,type="l",ylab="norm",lwd=3,col=rgb(0,0,1))
k
theta

标签: ralgorithmmathematical-optimization

解决方案


推荐阅读