python - 当我将 weight_decay 参数添加到 PyTorch 中的优化器时,我的训练精度保持在 10%。我正在使用 CIFAR10 数据集和 LeNet CNN 模型
问题描述
我正在 LeNet CNN 模型上训练 CIFAR10 数据集。我在 Google Colab 上使用 PyTorch。只有当我使用带有 model.parameters() 作为唯一参数的 Adam 优化器时,代码才会运行。但是当我改变我的优化器或使用 weight_decay 参数时,精度在所有时期都保持在 10%。我无法理解它发生的原因。
# CNN Model - LeNet
class LeNet_ReLU(nn.Module):
def __init__(self):
super().__init__()
self.cnn_model = nn.Sequential(nn.Conv2d(3,6,5),
nn.ReLU(),
nn.AvgPool2d(2, stride=2),
nn.Conv2d(6,16,5),
nn.ReLU(),
nn.AvgPool2d(2, stride=2))
self.fc_model = nn.Sequential(nn.Linear(400, 120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10))
def forward(self, x):
x = self.cnn_model(x)
x = x.view(x.size(0), -1)
x = self.fc_model(x)
return x
# Importing dataset and creating dataloader
batch_size = 128
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
transform=transforms.ToTensor())
trainloader = utils_data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
transform=transforms.ToTensor())
testloader = utils_data.DataLoader(testset, batch_size=batch_size, shuffle=False)
# Creating instance of the model
net = LeNet_ReLU()
# Evaluation function
def evaluation(dataloader):
total, correct = 0, 0
for data in dataloader:
inputs, labels = data
outputs = net(inputs)
_, pred = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (pred==labels).sum().item()
return correct/total * 100
# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(net.parameters(), weight_decay = 0.9)
# Model training
loss_epoch_arr = []
max_epochs = 16
for epoch in range(max_epochs):
for i, data in enumerate(trainloader, 0):
inputs, labels = data
outputs = net(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
opt.step()
opt.zero_grad()
loss_epoch_arr.append(loss.item())
print('Epoch: %d/%d, Test acc: %0.2f, Train acc: %0.2f'
% (epoch,max_epochs, evaluation(testloader), evaluation(trainloader)))
plt.plot(loss_epoch_arr)
解决方案
权重衰减机制为高值权重设置了惩罚,即通过将权重的总和乘以weight_decay
您给出的参数来限制权重具有相对较小的值。这可以看作是一个二次正则化项。
当传递较大weight_decay
的值时,您可能会过于严格地限制您的网络并阻止它学习,这可能是它具有 10% 的准确率的原因,这与非学习有关并且只是猜测答案(因为您收到了 10 个课程10% 的 acc,当输出根本不是您输入的函数时)。
解决方案是在该区域使用不同的值、训练weight_decay
或1e-4
其他值。请注意,当您达到接近零的值时,您应该得到更接近初始训练的结果,而不使用权重衰减。
希望有帮助。
推荐阅读
- java - Mockito 在调用 doCallRealMethod 时抛出 NullpointerException
- javascript - React - 访问导出的常量
- firebase-storage - 将获取的数据存储到状态组件中
- python-3.x - 如何将 python matplotlib.pyplot 图例标记更改为 1、2、3 之类的序列号,而不是形状或字符?
- java - Firebase 中的多态性
- node.js - 程序化 Webpack & Jest (ESM):无法解析没有“.js”文件扩展名的模块
- flutter - 如何修改整个应用程序的 Scaffold 小部件
- python - 使用 DatetimeIndex 重新索引
- r - R: model.frame.default(formula = class ~ step + type + amount + :) 中的错误:对象不是矩阵
- python - 三模型django相关的prefetch怎么做?嵌套预取django