首页 > 解决方案 > PyTorch 的 torch.autograd.grad 中 grad_outputs 的含义

问题描述

我无法理解 中grad_outputs选项的概念含义torch.autograd.grad

文档说:

grad_outputs应该是包含雅可比向量积中的“向量”的长度匹配输出序列,通常是每个输出的预计算梯度。如果输出不是require_grad,那么梯度可以是None)。

我觉得这个描述很神秘。Jacobian-vector product到底是什么意思?我知道雅可比矩阵是什么,但不确定它们在这里指的是什么产品:元素方面的产品,矩阵产品,还是别的?我无法从下面的示例中看出。

为什么引号中的“向量” ?实际上,在下面的示例中,当是向量时出现错误grad_outputs,但当它是矩阵时却没有。

>>> x = torch.tensor([1.,2.,3.,4.], requires_grad=True)
>>> y = torch.outer(x, x)

为什么我们观察到以下输出;它是如何计算的?

>>> y
tensor([[ 1.,  2.,  3.,  4.],
        [ 2.,  4.,  6.,  8.],
        [ 3.,  6.,  9., 12.],
        [ 4.,  8., 12., 16.]], grad_fn=<MulBackward0>)

>>> torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
(tensor([20., 20., 20., 20.]),)

但是,为什么会出现这个错误?

>>> torch.autograd.grad(y, x, grad_outputs=torch.ones_like(x))  

RuntimeError: Mismatch in shape:grad_output[0]形状为torch.Size([4])output[0]形状为torch.Size([4, 4])

标签: pytorchautograd

解决方案


如果我们以您的示例为例,我们将具有f输入xshape(n,)和输出y = f(x)shape 的功能(n, n)。输入被描述为列向量[x_i]_i for i ∈ [1, n],并被f(x)定义为矩阵[y_jk]_jk = [x_j*x_k]_jk for j, k ∈ [1, n]²

计算输出相对于输入的梯度通常很有用(或者有时 wrt 的参数f,这里没有)。不过,在更一般的情况下,我们希望计算dL/dx而不仅仅是dy/dxdL/dx的偏导数是L计算自y, wrt x

计算图如下所示:

x.grad = dL/dx <-------   dL/dy y.grad
                dy/dx
       x       ------->    y = x*xT

然后,如果我们看dL/dx,也就是通过链式法则等于dL/dy*dy/dx。看看 的接口,我们有torch.autograd.grad以下对应关系:

  • outputs<-> y,
  • inputs<->x
  • grad_outputs<-> dL/dy

查看形状:dL/dx应该具有与x(dL/dx可以称为 的“梯度” x) 相同的形状,而dy/dx雅可比矩阵将是 3 维的。另一方面dL/dy,作为输入梯度的 ,应该与输出具有相同的形状,y形状。

我们要计算dL/dx = dL/dy*dy/dx. 如果我们仔细观察,我们有

dy/dx = [dy_jk/dx_i]_ijk for i, j, k ∈ [1, n]³

所以,

dL/dx = [dL/d_x_i]_i, i ∈ [1,n]
      = [sum(dL/dy_jk * d(y_jk)/dx_i over j, k ∈ [1, n]²]_i, i ∈ [1,n]

回到您的示例,这意味着对于给定的i ∈ [1, n]: dL/dx_i = sum(dy_jk/dx_i) over j, k ∈ [1,n]²。并且dy_jk/dx_i = f(x_j*x_k)/dx_i将等于x_jif i = kx_kifi = j2*x_iif i = j = k(因为平方x_i)。据说矩阵y是对称的......所以结果归结为2*sum(x_i) over i ∈ [1, n]

这意味着dL/dx是列向量[2*sum(x)]_i for i ∈ [1, n]

>>> 2*x.sum()*torch.ones_like(x)
tensor([20., 20., 20., 20.])

回过头来看看这个其他图形示例,这里在后面添加一个额外的操作y

  x   ------->  y = x*xT  -------->  z = y²

如果您查看此图上的反向传递,您有:

dL/dx <-------   dL/dy    <--------  dL/dz
        dy/dx              dz/dy 
  x   ------->  y = x*xT  -------->  z = y²

dL/dx = dL/dy*dy/dx = dL/dz*dz/dy*dy/dx实际上,它分两个连续的步骤计算:,dL/dy = dL/dz*dz/dy然后dL/dx = dL/dy*dy/dx


推荐阅读