numpy - 这个神经网络示例是我在看错误还是我不理解反向传播?
问题描述
这个模型是在两个地方使用一个 relu,还是通过对一层两侧的层进行矩阵乘法来计算梯度?
在这个简单神经网络的最后一层(下图)中,它w2
通过对 y 预测进行矩阵乘法来计算最后一层的梯度 - y 和h_relu
,我认为这只是层之间w1
而w2
不是w2
和之间y_pred
有问题的线靠近底部。它是grad_w2 = h_relu.t().mm(grad_y_pred)
。
我很困惑,因为我认为一切都应该按顺序前进并按顺序倒退。这个 relu 是在两个地方使用的吗?
这是对模型的视觉说明的尝试。
这个例子来自Pytorch 网站。它是页面上的第二个代码块。
grad_w2 = h_relu.t().mm(grad_y_pred)
import torch
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)
# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)
learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.mm(w1)
h_relu = h.clamp(min=0)
y_pred = h_relu.mm(w2)
# Compute and print loss
loss = (y_pred - y).pow(2).sum().item()
if t % 100 == 99:
print(t, loss)
# Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.t().mm(grad_y_pred)
grad_h_relu = grad_y_pred.mm(w2.t())
grad_h = grad_h_relu.clone()
grad_h[h < 0] = 0
grad_w1 = x.t().mm(grad_h)
# Update weights using gradient descent
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2
感谢您耐心地看着这个并试图为我解决这个问题。
如果您可以尝试在中间添加另一层 whieghts 和另一个可能有助于我理解的 relu。这就是我想要做的。
解决方案
考虑下图,它代表了所讨论的网络。反向传播的概念只是一种快速直观地将链式法则应用于复杂的操作序列以计算张量输出梯度的方法。通常我们感兴趣的是计算叶张量(不是从其他张量派生的张量)相对于损失或目标的梯度。下图中所有的叶子张量都用圆圈表示,损失用带有 L 标签的矩形表示。
使用反向图,我们可以遵循从 L 到 w1 和 w2 的路径,以确定我们需要哪些偏导数来计算 L wrt w1 和 w2 的梯度。为简单起见,我们将假设所有叶张量都是标量,以避免陷入向量和矩阵相乘的复杂性。
使用这种方法,L wrt w1 和 w2 的梯度是
和
需要注意的是,由于 w2 是叶张量,因此我们仅grad_w2
在计算 dL/dw2 期间使用 dy/dw2(aka ),因为它不是从 L 到 w1 的路径的一部分。
推荐阅读
- ios - 如何在 CoreData 中存储 UIImageView 的位置、大小和旋转以填充不同的 ViewController?
- c# - 使用 Newtonsoft 将带有表达式过滤器的 JSON 路径应用于非数组 JSON 元素
- sql - 根据来自另一个表的传入列名从一个表中选择列
- split - 如何先拆分字符串,然后在机器人框架中使用get Substring
- python - Pandas 数据框 csv 导出中的空白未修剪
- python - 如何对有时间的python列表进行下采样
- visual-studio-code - 无法在 VS Code 中将 jupyter notebook 转换为 python 脚本
- excel - 使用高级 Excel 进行数据处理
- php - 使用 laravel voyager 管理面板上传大文件不起作用
- linux - 如何从 bash 命令获取数字(四舍五入)?(从字符串中提取)