machine-learning - 反向传播(Andrew Ng 的 Cousera ML)梯度下降说明
问题描述
问题
请原谅我问 Coursera ML 课程的具体问题。希望做过couser的人能解答一下。
在Coursera ML Week 4 Multi-class Classification and Neural Networks assignment中,为什么权重(theta)梯度是加(加)导数而不是减法?
% Calculate the gradients of Weight2
% Derivative at Loss function J=L(Z) : dJ/dZ = (oi-yi)/oi(1-oi)
% Derivative at Sigmoid activation function dZ/dY = oi(1-oi)
delta_theta2 = oi - yi; % <--- (dJ/dZ) * (dZ/dY)
# Using +/plus NOT -/minus
Theta2_grad = Theta2_grad + <-------- Why plus(+)?
bsxfun(@times, hi, transpose(delta_theta2));
代码摘录
for i = 1:m
% i is training set index of X (including bias). X(i, :) is 401 data.
xi = X(i, :);
yi = Y(i, :);
% hi is the i th output of the hidden layer. H(i, :) is 26 data.
hi = H(i, :);
% oi is the i th output layer. O(i, :) is 10 data.
oi = O(i, :);
%------------------------------------------------------------------------
% Calculate the gradients of Theta2
%------------------------------------------------------------------------
delta_theta2 = oi - yi;
Theta2_grad = Theta2_grad + bsxfun(@times, hi, transpose(delta_theta2));
%------------------------------------------------------------------------
% Calculate the gradients of Theta1
%------------------------------------------------------------------------
% Derivative of g(z): g'(z)=g(z)(1-g(z)) where g(z) is sigmoid(H_NET).
dgz = (hi .* (1 - hi));
delta_theta1 = dgz .* sum(bsxfun(@times, Theta2, transpose(delta_theta2)));
% There is no input into H0, hence there is no theta for H0. Remove H0.
delta_theta1 = delta_theta1(2:end);
Theta1_grad = Theta1_grad + bsxfun(@times, xi, transpose(delta_theta1));
end
我认为这是减去导数。
解决方案
由于梯度是通过对所有训练示例的梯度进行平均来计算的,因此我们首先在遍历所有训练示例时“累积”梯度。我们通过对所有训练示例的梯度求和来做到这一点。因此,您用加号突出显示的行不是渐变更新步骤。(请注意,alpha 也不存在。)它可能在其他地方。它很可能在从 1 到 m 的循环之外。
另外,我不确定你什么时候会了解这个(我确定它在课程的某个地方),但你也可以对代码进行矢量化:)
推荐阅读
- c# - 在 Windows Server 2008 R2 中的共享和映射驱动器中搜索和读取文件很慢
- python - 连接到 PostgreSQL 时如何在 Python 脚本中隐藏密码?
- sql - Postgres,函数,在插入 vladate 数组值之前
- bash - Shell脚本进入Docker容器并执行命令,最终退出
- defaultdict - defaultdict - 第一个参数必须是可调用的或无
- python - Python食谱记录套接字服务器客户端记录器不发送数据包
- javascript - 如何汇总表格视图中的列值并将其显示在 Google 图表的标签中?
- office-js - 使用桌面客户端时未加载 Word Web 插件功能区图标
- raku - Cro::WebSocket::Client 不起作用
- sql - T-SQL - 计算循环模式