python - PyTorch 中手动权重更新中的梯度为零
问题描述
我正在尝试使用 AUTOGRAD 手动更新 MNIST 的权重来实现一个简单的神经网络,类似于此处给出的 AUTOGRAD 示例。这是我的代码:
import os
import sys
import torch
import torchvision
class Datasets:
"""Helper for extracting datasets."""
def __init__(self, root='data/', batch_size=25):
if not os.path.exists(root):
os.mkdir(root)
self.root = root
self.batch_size = batch_size
def get_mnist_loaders(self):
train_data = torchvision.datasets.MNIST(
root=self.root, train=True, download=True)
test_data = torchvision.datasets.MNIST(
root=self.root, train=False, download=True)
train_loader = torch.utils.data.DataLoader(
dataset=train_data, batch_size=self.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_data, batch_size=self.batch_size, shuffle=False)
return train_loader, test_loader
def create_batches(self, data, labels, batch_size):
return [(data[i:i+batch_size], labels[i:i+batch_size])
for i in range(0, len(data), max(1, batch_size))]
def train1():
dtype = torch.float
n_inputs = 28*28
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10
batch_size = 200
n_epochs = 25
learning_rate = 0.01
test_step = 100
device = torch.device("cpu")
datasets = Datasets(batch_size=batch_size)
train_loader, test_loader = datasets.get_mnist_loaders()
def feed_forward(X):
x_shape = list(X.size())
X = X.view(x_shape[0], x_shape[1]*x_shape[2])
hidden1 = torch.mm(X, w1)
hidden1 += b1
hidden1 = hidden1.clamp(min=0)
hidden2 = torch.mm(hidden1, w2) + b2
hidden2 = hidden2.clamp(min=0)
logits = torch.mm(hidden2, w3) + b3
softmax = pytorch_softmax(logits)
return softmax
def accuracy(y_pred, y):
if list(y_pred.size()) != list(y.size()):
raise ValueError('Inputs have different shapes.')
total_correct = 0
total = 0
for i, (y1, y2) in enumerate(zip(y_pred, y)):
if y1 == y2:
total_correct += 1
total += 1
return total_correct / total
w1 = torch.randn(n_inputs, n_hidden1, device=device, dtype=dtype, requires_grad=True)
b1 = torch.nn.Parameter(torch.zeros(n_hidden1), requires_grad=True)
w2 = torch.randn(n_hidden1, n_hidden2, requires_grad=True)
b2 = torch.nn.Parameter(torch.zeros(n_hidden2), requires_grad=True)
w3 = torch.randn(n_hidden2, n_outputs, dtype=dtype, requires_grad=True)
b3 = torch.nn.Parameter(torch.zeros(n_outputs), requires_grad=True)
pytorch_softmax = torch.nn.Softmax(0)
pytorch_cross_entropy = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')
step = 0
for epoch in range(n_epochs):
batches = datasets.create_batches(train_loader.dataset.train_data,
train_loader.dataset.train_labels,
batch_size)
for x, y in batches:
step += 1
softmax = feed_forward(x.float())
vals, y_pred = torch.max(softmax, 1)
accuracy_ = accuracy(y_pred, y)
cross_entropy = pytorch_cross_entropy(softmax, y)
print(epoch, step, cross_entropy.item(), accuracy_)
cross_entropy.backward()
with torch.no_grad():
w1 -= learning_rate * w1.grad
w2 -= learning_rate * w2.grad
w3 -= learning_rate * w3.grad
b1 -= learning_rate * b1.grad
b2 -= learning_rate * b2.grad
b3 -= learning_rate * b3.grad
w1.grad.zero_()
w2.grad.zero_()
w3.grad.zero_()
b1.grad.zero_()
b2.grad.zero_()
b3.grad.zero_()
if __name__ == '__main__':
train1()
然而,网络似乎没有训练。当我打印部分渐变(例如w1.grad.data[:10, :10]
)时,它们由零组成。我尝试使用weight.data
并weight.grad.data
更新权重并尝试删除该w.grad.zero_()
部分(即使它在示例中)但它没有帮助。这里有什么问题?
解决方案
这里有3个问题。
首先,您采用 softmax 的轴是错误的。它应该在最后一个轴上。
pytorch_softmax = torch.nn.Softmax(-1)
其次,您logits
的数字非常大。由此产生的导数是一个非常小的数字,因此您看到的是零。
tensor([[ -95782.0859, -30961.9023, -3614.0188, ..., -328240.6250,
-40818.2227, -160598.5469],
[-182128.5938, -76499.2969, 143654.6250, ..., -300924.1250,
-74291.3125, -109025.0391],
[-163018.4062, -71817.1172, -134466.0156, ..., -49884.1211,
-19183.3691, 116674.1406],
...,
[ 225013.4219, -37008.6484, 244807.2188, ..., -466822.8750,
63626.5625, -147146.0781],
[ 122045.7031, -90937.7344, 77259.1641, ..., -397063.9375,
-188736.9688, -78475.5000],
[ 23139.7578, -14914.8359, -205065.0625, ..., -65808.6562,
31458.8906, -11362.2344]], grad_fn=<AddBackward0>)
您可以做的几件事包括规范化您的数据、添加 BatchNorm、钳制等。我可以看到您的数据X
是一个张量,其值范围从 0 到 255。
第三,你不应该需要包装你的张量,nn.Parameter
因为它们只与nn.Module
类一起使用。
推荐阅读
- woocommerce - 在结账时用相机扫描信用卡 - Woocommerce & Stripe
- firebase - 在 Firebase 云功能中指定 RAM
- reactjs - React 路由器仅适用于 URL 中的 /#/,否则在 Ubuntu 服务器上会出现错误 404
- r - 添加循环以读取多个pdf页面时R中的语法错误
- python - 是否可以使用 Tkinter 创建适用于 Android 或 iOS 的应用程序?
- reactjs - 将正弦波包示例转换为反应组件
- c# - 如何在应用程序中使用多个 openGL 状态?
- javascript - 如何在当前 html 页面上显示来自 textarea 评论框的信息
- python - Python - Pandas - 忽略空格后的数据
- c# - UserPrincipal.SetPassword 抛出 0x80070005 (E_ACCESSDENIED)