Zero gradients in manual weight update in PyTorch

Problem description

I am trying to implement a simple neural network for MNIST with manually updated weights, using autograd, similar to the autograd example given here. This is my code:

import os
import sys

import torch
import torchvision
class Datasets:
    """Helper for extracting datasets."""

    def __init__(self, root='data/', batch_size=25):
        if not os.path.exists(root):
            os.mkdir(root)
        self.root = root
        self.batch_size = batch_size

    def get_mnist_loaders(self):
        train_data = torchvision.datasets.MNIST(
                root=self.root, train=True, download=True)
        test_data = torchvision.datasets.MNIST(
                root=self.root, train=False, download=True)


        train_loader = torch.utils.data.DataLoader(
                dataset=train_data, batch_size=self.batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
                dataset=test_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, test_loader

    def create_batches(self, data, labels, batch_size):
        return [(data[i:i+batch_size], labels[i:i+batch_size])
            for i in range(0, len(data), max(1, batch_size))]

def train1():
    dtype = torch.float
    n_inputs = 28*28
    n_hidden1 = 300
    n_hidden2 = 100
    n_outputs = 10
    batch_size = 200
    n_epochs = 25
    learning_rate = 0.01
    test_step = 100 
    device = torch.device("cpu")

    datasets = Datasets(batch_size=batch_size)
    train_loader, test_loader = datasets.get_mnist_loaders()

    def feed_forward(X):
        x_shape = list(X.size())
        X = X.view(x_shape[0], x_shape[1]*x_shape[2])
        hidden1 = torch.mm(X, w1)
        hidden1 += b1
        hidden1 = hidden1.clamp(min=0)
        hidden2 = torch.mm(hidden1, w2) + b2
        hidden2 = hidden2.clamp(min=0)
        logits = torch.mm(hidden2, w3) + b3
        softmax = pytorch_softmax(logits)
        return softmax

    def accuracy(y_pred, y):

        if list(y_pred.size()) != list(y.size()):
            raise ValueError('Inputs have different shapes.')

        total_correct = 0
        total = 0
        for i, (y1, y2) in enumerate(zip(y_pred, y)):
            if y1 == y2:
                total_correct += 1
            total += 1

        return total_correct / total

    w1 = torch.randn(n_inputs, n_hidden1, device=device, dtype=dtype, requires_grad=True)
    b1 = torch.nn.Parameter(torch.zeros(n_hidden1), requires_grad=True)

    w2 = torch.randn(n_hidden1, n_hidden2, requires_grad=True)
    b2 = torch.nn.Parameter(torch.zeros(n_hidden2), requires_grad=True)

    w3 = torch.randn(n_hidden2, n_outputs, dtype=dtype, requires_grad=True)
    b3 = torch.nn.Parameter(torch.zeros(n_outputs), requires_grad=True)

    pytorch_softmax = torch.nn.Softmax(0)
    pytorch_cross_entropy = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')

    step = 0
    for epoch in range(n_epochs):
        batches = datasets.create_batches(train_loader.dataset.train_data,
                                          train_loader.dataset.train_labels,
                                          batch_size)
        for x, y in batches:
            step += 1

            softmax = feed_forward(x.float())
            vals, y_pred = torch.max(softmax, 1)
            accuracy_ = accuracy(y_pred, y)
            cross_entropy = pytorch_cross_entropy(softmax, y)

            print(epoch, step, cross_entropy.item(), accuracy_)

            cross_entropy.backward()

            with torch.no_grad():
                w1 -= learning_rate * w1.grad
                w2 -= learning_rate * w2.grad
                w3 -= learning_rate * w3.grad

                b1 -= learning_rate * b1.grad
                b2 -= learning_rate * b2.grad
                b3 -= learning_rate * b3.grad

                w1.grad.zero_()
                w2.grad.zero_()
                w3.grad.zero_()

                b1.grad.zero_()
                b2.grad.zero_()
                b3.grad.zero_()

if __name__ == '__main__':
    train1()

However, the network does not seem to train. When I print some of the gradients (e.g. w1.grad.data[:10, :10]), they consist of zeros. I tried updating the weights using weight.data and weight.grad.data, and I tried removing the w.grad.zero_() part (even though it is in the example), but it did not help. What is wrong here?

Tags: python, pytorch

Solution


There are three problems here.

First, the axis you take the softmax over is wrong. It should be taken over the last axis, which holds the class scores:

pytorch_softmax = torch.nn.Softmax(-1)

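A quick standalone check (not part of the original code) makes the difference visible: with Softmax(0) each class column sums to 1 across the batch, while Softmax(-1) turns each row into a proper per-sample probability distribution.

import torch

logits = torch.randn(4, 10)  # a batch of 4 samples with 10 class scores each

# dim=0 normalizes down the batch: every class *column* sums to 1.
print(torch.nn.Softmax(0)(logits).sum(0))    # ten ones
# dim=-1 normalizes across the classes: every *row* sums to 1.
print(torch.nn.Softmax(-1)(logits).sum(-1))  # four ones
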
Second, your logits are very large numbers. The softmax saturates on values like these, so the resulting derivatives are vanishingly small, which is why you are seeing zeros:

tensor([[ -95782.0859,  -30961.9023,   -3614.0188,  ..., -328240.6250,
          -40818.2227, -160598.5469],
        [-182128.5938,  -76499.2969,  143654.6250,  ..., -300924.1250,
          -74291.3125, -109025.0391],
        [-163018.4062,  -71817.1172, -134466.0156,  ...,  -49884.1211,
          -19183.3691,  116674.1406],
         ...,
        [ 225013.4219,  -37008.6484,  244807.2188,  ..., -466822.8750,
           63626.5625, -147146.0781],
        [ 122045.7031,  -90937.7344,   77259.1641,  ..., -397063.9375,
         -188736.9688,  -78475.5000],
        [  23139.7578,  -14914.8359, -205065.0625,  ...,  -65808.6562,
           31458.8906,  -11362.2344]], grad_fn=<AddBackward0>)

A few things you can do include normalizing your data (sketched below), adding BatchNorm, clamping, and so on. I can see that your data X is a tensor whose values range from 0 to 255.
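
As a minimal sketch of the normalization option (the transform-based loader and the normalize helper are illustrative, not taken from the question):

import torch
import torchvision

# Option A: normalize at load time. transforms.ToTensor converts the
# uint8 images to float tensors and scales pixel values to [0.0, 1.0].
train_data = torchvision.datasets.MNIST(
        root='data/', train=True, download=True,
        transform=torchvision.transforms.ToTensor())

# Option B: scale the raw uint8 batches manually inside the training
# loop, which fits the create_batches() approach from the question.
def normalize(x):
    return x.float() / 255.0  # map [0, 255] to [0, 1]

Scaling the random weight initialization down (for example, multiplying torch.randn by a small factor such as 0.01) also helps keep the initial logits in a reasonable range.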

Third, you should not need to wrap your tensors in nn.Parameter; that class is only meant to be used with nn.Module subclasses.
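
Plain leaf tensors with requires_grad=True are all the manual update loop needs. A sketch reusing the sizes from the question:

import torch

n_hidden1, n_hidden2, n_outputs = 300, 100, 10  # sizes from the question

# Plain tensors work with autograd and the manual update; nn.Parameter
# only adds automatic registration when assigned to an nn.Module.
b1 = torch.zeros(n_hidden1, requires_grad=True)
b2 = torch.zeros(n_hidden2, requires_grad=True)
b3 = torch.zeros(n_outputs, requires_grad=True)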

