首页 > 解决方案 > 手动计算梯度与 PyTorch 链式法则的数值差异

问题描述

我正在使用这个非常简单的线性网络的公式手动计算梯度,并带有 MSE 损失。

然后,我与 PyTorch 计算的梯度进行比较,并使用 PyTorch 的allclose函数检查 PyTorch 是否正确计算了梯度(即手动计算的梯度和 pytorch 之间的相对差异足够小)。
由于公式正确,所有测试都应通过。但对于某些种子来说,它只是没有。

所以很明显 PyTorch 没有做错任何事情,但由于公式是正确的,它必须来自公式中的一些数值不稳定性问题。

import torch

class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        self.linear = torch.nn.Linear(10, 1)


    def forward(self, x):
        return self.linear(x)

loss = torch.nn.MSELoss()

for i in range(0, 1000):
    torch.manual_seed(i)
    X = torch.randn(100, 10)
    y = torch.randn(100, 1)





    model=Network()
    model.train()
    optimizer=torch.optim.SGD(model.parameters(),lr=1.)
    optimizer.zero_grad()
    output = loss(model(X), y)
    output.backward()

    torch_grads=[]
    for p in model.parameters():
        torch_grads.append(p.grad.detach().data)



    #df/dW = (-2X.T*y+2*X.T*b+2*X.T*X*W)/nsamples 
    #df/db = (2*b-2*y+2*W.T*X.T).mean() (the mean comes from implicit broadcasting of b)

    theory_grad_w = (-2 * torch.matmul(X.t(), y)
                     +2 * torch.matmul(torch.t(X), torch.ones((X.shape[0], 1)))* list(model.parameters())[1]
                     +2 * torch.matmul(torch.matmul(X.t(), X), list(model.parameters())[0].t())
                     ) / float(X.shape[0])

    theory_grad_w = theory_grad_w.t()


    theory_grad_b = torch.mean(2 * list(model.parameters())[1]- 2 * y+ 2 * torch.matmul((list(model.parameters())[0]), torch.t(X)))

    theory_grads = [theory_grad_w, theory_grad_b]

    b=all([torch.allclose(u, d) for u, d in zip(torch_grads, theory_grads)])
    if not(b):

      print("i=%s, pass=%s"%(i, b))

观察到的数值不稳定性的来源是什么以及如何处理它们,以便测试始终通过。这只是对操作进行不同排序的问题吗?

标签: pythondeep-learningpytorchbackpropagation

解决方案


推荐阅读