Mini-batch implementation in MATLAB

Problem description

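I implemented a mini-batch stochastic gradient descent algorithm and used it to train a small neural network on a classification problem, but after rounding, every prediction is zero. Either the model is useless or there is a bug in my implementation.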

Confusion matrix (rows: true label, columns: predicted label after rounding):

            pred 0   pred 1
    true 0     122        0
    true 1      78        0
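All 200 samples fall in the "pred 0" column, so the accuracy equals the share of class 0 (122/200 = 0.61), which is what a constant "always predict 0" classifier scores. A minimal sketch of how such a table can be computed from rounded sigmoid outputs; `y`, `p` and `pred` are hypothetical names, and `p` is a made-up constant output chosen only to reproduce the situation described:

```matlab
% Hypothetical data: y = true 0/1 labels, p = network outputs (a5), all below 0.5
y = [zeros(1,122), ones(1,78)];   % class counts taken from the table above
p = 0.3 * ones(1,200);            % illustrative outputs, not real results
pred = round(p);                  % rounding thresholds the sigmoid output at 0.5

% C(i,j) counts samples with true label i-1 and predicted label j-1
C = accumarray([y(:)+1, pred(:)+1], 1, [2 2]);
disp(C)                           % C = [122 0; 78 0]

acc = mean(pred == y);            % accuracy of the all-zero predictor
fprintf('accuracy = %.2f\n', acc) % prints accuracy = 0.61
```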

I can't tell where the error is.
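The batching step in the training loop relies on two MATLAB behaviors: `reshape(rand_idx,[],num_data/batch_size)` turns the shuffled indices into a `batch_size`-by-`num_batches` matrix (and errors if `num_data` is not divisible by `batch_size`), and `for idx = M` iterates over the columns of `M`. A small standalone sketch with toy sizes (hypothetical values, not the real data) that checks every sample is visited exactly once per epoch:

```matlab
% Toy sizes for illustration only
num_data = 12;
batch_size = 4;
assert(mod(num_data, batch_size) == 0, 'num_data must be divisible by batch_size');

rand_idx = randperm(num_data);                           % reshuffle every epoch
batches = reshape(rand_idx, [], num_data/batch_size);    % column k = indices of batch k

% "for idx = batches" assigns one column per iteration (a batch_size-by-1 vector)
seen = [];
for idx = batches
    fprintf('batch of %d: %s\n', numel(idx), mat2str(idx'));
    seen = [seen; idx];                                  %#ok<AGROW>
end
```

Because `idx` is a column vector, `x1(:,idx)` selects `batch_size` columns of `x1`, which matches how the forward pass uses it.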

%% Model
rng(1024);

% initialize weights and bias
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5); 
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1); 
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);

eta = 5e-3; % learning rate
iter = 1000; % number of iterations
num_data = length(label);

loss_vec = zeros(1,iter);
tloss_vec = zeros(1,iter);

for it = 1:iter

    % mini-batch method
    batch_size = 50;
    rand_idx = randperm(num_data);
    rand_idx = reshape(rand_idx,[],num_data/batch_size);
    
    for idx = rand_idx
        % forward pass
        a2 = activate([x1(:,idx);x2(:,idx)], W2, b2);
        a3 = activate(a2,W3,b3);
        a4 = activate(a3,W4,b4);
        a5 = activate(a4,W5,b5);
        
        % backward pass (gradient)
        delta5 = a5.*(1-a5).*(a5-label(idx));
        delta4 = a4.*(1-a4).*(W5'*delta5);
        delta3 = a3.*(1-a3).*(W4'*delta4);
        delta2 = a2.*(1-a2).*(W3'*delta3);
        
        % update weights and bias
        W2 = W2 - 1/length(idx)*eta*delta2*[x1(:,idx);x2(:,idx)]';
        W3 = W3 - 1/length(idx)*eta*delta3*a2';
        W4 = W4 - 1/length(idx)*eta*delta4*a3';
        W5 = W5 - 1/length(idx)*eta*delta5*a4';
        b2 = b2 - 1/length(idx)*eta*sum(delta2,2);
        b3 = b3 - 1/length(idx)*eta*sum(delta3,2);
        b4 = b4 - 1/length(idx)*eta*sum(delta4,2);
        b5 = b5 - 1/length(idx)*eta*sum(delta5,2);
        
    end

    % compute train loss and test loss once per epoch, after all mini-batches
    loss_vec(it) = 1/(2*num_data)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[x1;x2],label);
    tloss_vec(it) = 1/(2*length(tlabel))*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[tx1;tx2],tlabel);
end

% figure
fig = figure;   % keep the figure handle for saveas (plot returns a line handle)
plot(1:iter, loss_vec)
hold on
plot(1:iter, tloss_vec)
title("cost when \eta = " + eta)
xlabel("iteration")
ylabel("cost")
legend("Train", "Val")
% saveas(fig,"eta="+eta+"_method"+method+".png")
hold off
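The listing calls `activate` and `LossFunc`, which are not shown. Since the backward pass uses `a.*(1-a)` (the derivative of the logistic sigmoid) and the loss is scaled by `1/(2*num_data)`, a plausible assumption is that `activate(a,W,b)` returns `sigmoid(W*a + b)` and `LossFunc` returns the sum of squared output errors; the real helpers may differ. Under that assumption, a finite-difference gradient check is a standard way to rule a backpropagation bug in or out. This sketch builds the analytic gradient for `W2` from the same `delta` expressions as the loop and compares it with central differences of the loss. Run it as a separate script: it defines its own random weights and anonymous-function stand-ins for the helpers.

```matlab
% ASSUMED helpers (hypothetical stand-ins for the question's activate/LossFunc)
activate = @(a, W, b) 1 ./ (1 + exp(-(W*a + b)));       % sigmoid layer
forward  = @(x,W2,W3,W4,W5,b2,b3,b4,b5) ...
    activate(activate(activate(activate(x,W2,b2),W3,b3),W4,b4),W5,b5);
LossFunc = @(W2,W3,W4,W5,b2,b3,b4,b5,x,y) ...
    sum((forward(x,W2,W3,W4,W5,b2,b3,b4,b5) - y).^2);   % sum of squared errors

% Small random batch and weights, initialized like the question's code
rng(0);
B = 8;
x = randn(2,B);
y = double(rand(1,B) > 0.5);
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5); W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1); b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);

% Analytic gradient of 1/(2B)*LossFunc w.r.t. W2, using the question's deltas
a2 = activate(x,W2,b2);  a3 = activate(a2,W3,b3);
a4 = activate(a3,W4,b4); a5 = activate(a4,W5,b5);
delta5 = a5.*(1-a5).*(a5-y);
delta4 = a4.*(1-a4).*(W5'*delta5);
delta3 = a3.*(1-a3).*(W4'*delta4);
delta2 = a2.*(1-a2).*(W3'*delta3);
gradW2 = 1/B*delta2*x';

% Central differences of the same scaled loss, one entry of W2 at a time
h = 1e-5;
numW2 = zeros(size(W2));
for k = 1:numel(W2)
    Wp = W2; Wp(k) = Wp(k) + h;
    Wm = W2; Wm(k) = Wm(k) - h;
    numW2(k) = (LossFunc(Wp,W3,W4,W5,b2,b3,b4,b5,x,y) ...
              - LossFunc(Wm,W3,W4,W5,b2,b3,b4,b5,x,y)) / (2*h*2*B);
end

relerr = norm(gradW2(:) - numW2(:)) / max(norm(gradW2(:)) + norm(numW2(:)), eps);
fprintf('relative error for W2: %.2e\n', relerr)
```

A tiny relative error means the deltas agree with the assumed loss; a large one points at the backward pass. The same loop can be repeated for `W3`, `W4`, `W5` and the biases.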

Tags: matlab, optimization, deep-learning, neural-network
