首页 > 解决方案 > 这个神经网络示例是我在看错误还是我不理解反向传播?

问题描述

这个模型是在两个地方使用一个 relu,还是通过对一层两侧的层进行矩阵乘法来计算梯度?

在这个简单神经网络的最后一层(下图)中,它w2通过对 y 预测进行矩阵乘法来计算最后一层的梯度 - y 和h_relu,我认为这只是层之间w1w2不是w2和之间y_pred

有问题的线靠近底部。它是grad_w2 = h_relu.t().mm(grad_y_pred)
我很困惑,因为我认为一切都应该按顺序前进并按顺序倒退。这个 relu 是在两个地方使用的吗?

这是对模型的视觉说明的尝试。

在此处输入图像描述

这个例子来自Pytorch 网站。它是页面上的第二个代码块。

grad_w2 = h_relu.t().mm(grad_y_pred)


import torch

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

感谢您耐心地看着这个并试图为我解决这个问题。
如果您可以尝试在中间添加另一层 whieghts 和另一个可能有助于我理解的 relu。这就是我想要做的。

标签: numpymachine-learningpytorch

解决方案


考虑下图,它代表了所讨论的网络。反向传播的概念只是一种快速直观地将链式法则应用于复杂的操作序列以计算张量输出梯度的方法。通常我们感兴趣的是计算叶张量(不是从其他张量派生的张量)相对于损失或目标的梯度。下图中所有的叶子张量都用圆圈表示,损失用带有 L 标签的矩形表示。

在此处输入图像描述

使用反向图,我们可以遵循从 L 到 w1 和 w2 的路径,以确定我们需要哪些偏导数来计算 L wrt w1 和 w2 的梯度。为简单起见,我们将假设所有叶张量都是标量,以避免陷入向量和矩阵相乘的复杂性。

使用这种方法,L wrt w1 和 w2 的梯度是

在此处输入图像描述

在此处输入图像描述

需要注意的是,由于 w2 是叶张量,因此我们仅grad_w2在计算 dL/dw2 期间使用 dy/dw2(aka ),因为它不是从 L 到 w1 的路径的一部分。


推荐阅读