matlab - matlab中的小批量实现
问题描述
我实现了一个小批量随机梯度下降算法,然后将它与一个小的 nn 一起用于分类问题,但四舍五入后所有预测都为零。这要么意味着模型没用,要么我的实现中存在错误。
0 | 1 | |
---|---|---|
0 | 122 | 0 |
1 | 78 | 0 |
我无法确定错误在哪里。
%% Model
rng(1024);
% initialize weights and bias
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5);
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1);
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);
eta = 5e-3; % learning rate
iter = 1000; % number of iterations
num_data = length(label);
loss_vec = zeros(1,iter);
tloss_vec = zeros(1,iter);
for it = 1:iter
% mini-batch method
batch_size = 50;
rand_idx = randperm(num_data);
rand_idx = reshape(rand_idx,[],num_data/batch_size);
for idx = rand_idx
% forward pass
a2 = activate([x1(:,idx);x2(:,idx)], W2, b2);
a3 = activate(a2,W3,b3);
a4 = activate(a3,W4,b4);
a5 = activate(a4,W5,b5);
% backward pass (gradient)
delta5 = a5.*(1-a5).*(a5-label(idx));
delta4 = a4.*(1-a4).*(W5'*delta5);
delta3 = a3.*(1-a3).*(W4'*delta4);
delta2 = a2.*(1-a2).*(W3'*delta3);
% update weights and bias
W2 = W2 - 1/length(idx)*eta*delta2*[x1(:,idx);x2(:,idx)]';
W3 = W3 - 1/length(idx)*eta*delta3*a2';
W4 = W4 - 1/length(idx)*eta*delta4*a3';
W5 = W5 - 1/length(idx)*eta*delta5*a4';
b2 = b2 - 1/length(idx)*eta*sum(delta2,2);
b3 = b3 - 1/length(idx)*eta*sum(delta3,2);
b4 = b4 - 1/length(idx)*eta*sum(delta4,2);
b5 = b5 - 1/length(idx)*eta*sum(delta5,2);
% compute train loss and test loss
loss_vec(it) = 1/(2*num_data)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[x1;x2],label);
tloss_vec(it) = 1/(2*200)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[tx1;tx2],tlabel);
end
end
% figure
fig = plot(1:iter, loss_vec);
hold on
plot(1:iter, tloss_vec)
title("cost when \eta = ", eta)
xlabel("iteration")
ylabel("cost")
legend("Train", "Val")
% saveas(fig,"eta="+eta+"_method"+method+".png")
hold off
解决方案
推荐阅读
- rust - 无法将关闭传递给“Hyper::service_fn”
- css - 如何在 asp razor 中更改用户的负载 css
- sql-server - 从没有 CURSOR 的表中减去查询结果?
- c++ - 如何将条件放入char数组中?
- html - 为什么 angular2 ngIf 在 HTML 文件中不起作用?
- pyspark - 使用 pyspark 重新分区失败并出现错误
- java - JavaFX 非法访问错误
- python - 我们不能将对象传递给写函数以将一个文件中的数据添加到python中的另一个文件吗?
- c# - 尽管 SingleOrDefault() 将返回 null,但 EF Core LINQ 查询何时导致 null 合并?
- typescript - 'handleListKeyDown',缺少返回类型注释