python - 基于 Pytorch 的 Resnet18 在 CIFAR100 上实现了低准确率
问题描述
我正在 CIFAR100 数据集上训练 resnet18。大约 50 次迭代后,验证准确率收敛在 34% 左右。而训练准确率几乎达到了100%。
我怀疑它有点过拟合,所以我应用了像RandomHorizontalFlip
and这样的数据增强RandomRotation
,这使得验证收敛在 40% 左右。
我还尝试了衰减学习率[0.1, 0.03, 0.01, 0.003, 0.001]
,每 50 次迭代后衰减。衰减的学习率似乎并没有提高性能。
听说 CIFAR100 上的 Resnet 可以达到 70%~80% 的准确率。我还能应用什么技巧?或者我的实施有什么问题吗?CIFAR10 上的相同代码可以达到 80% 左右的准确率。
我的整个培训和评估代码如下:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.transforms import Compose, ToTensor, RandomHorizontalFlip, RandomRotation, Normalize
from torchvision.datasets import CIFAR10, CIFAR100
import os
from datetime import datetime
import matplotlib.pyplot as plt
def draw_loss_curve(histories, legends, save_dir):
os.makedirs(save_dir, exist_ok=True)
for key in histories[0][0].keys():
if key != "epoch":
plt.figure()
plt.title(key)
for history in histories:
x = [h["epoch"] for h in history]
y = [h[key] for h in history]
# plt.ylim(ymin=0, ymax=3.0)
plt.plot(x, y)
plt.legend(legends)
plt.savefig(os.path.join(save_dir, key + ".png"))
def cal_acc(out, label):
batch_size = label.shape[0]
pred = torch.argmax(out, dim=1)
num_true = torch.nonzero(pred == label).shape[0]
acc = num_true / batch_size
return torch.tensor(acc)
class LrManager(optim.lr_scheduler.LambdaLR):
def __init__(self, optimizer, lrs):
def f(epoch):
rate = 1
for k in sorted(lrs.keys()):
if epoch >= k:
rate = lrs[k]
else:
break
return rate
super(LrManager, self).__init__(optimizer, f)
def main(cifar=100, epochs=250, batches_show=100):
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print("warning: CUDA is not available, using CPU instead")
dataset_cls = CIFAR10 if cifar == 10 else CIFAR100
dataset_train = dataset_cls(root=f"data/{dataset_cls.__name__}/", download=True, train=True,
transform=Compose([RandomHorizontalFlip(), RandomRotation(15), ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
dataset_val = dataset_cls(root=f"data/{dataset_cls.__name__}/", download=True, train=False,
transform=Compose([ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
loader_train = DataLoader(dataset_train, batch_size=128, shuffle=True)
loader_val = DataLoader(dataset_val, batch_size=128, shuffle=True)
model = resnet18(pretrained=False, num_classes=cifar).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
lr_scheduler = LrManager(optimizer, {0: 1.0, 50: 0.3, 100: 0.1, 150: 0.03, 200: 0.01})
criterion = nn.CrossEntropyLoss()
history = []
model.train()
for epoch in range(epochs):
print("------------------- TRAINING -------------------")
loss_train = 0.0
running_loss = 0.0
acc_train = 0.0
running_acc = 0.0
for batch, data in enumerate(loader_train, 1):
img, label = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
pred = model(img)
loss = criterion(pred, label)
loss.backward()
optimizer.step()
running_loss += loss.item()
loss_train += loss.item()
acc = cal_acc(pred, label)
running_acc += acc.item()
acc_train += acc.item()
if batch % batches_show == 0:
print(f"epoch: {epoch}, batch: {batch}, loss: {running_loss/batches_show:.4f}, acc: {running_acc/batches_show:.4f}")
running_loss = 0.0
running_acc = 0.0
loss_train = loss_train / batch
acc_train = acc_train / batch
lr_scheduler.step()
print("------------------- EVALUATING -------------------")
with torch.no_grad():
running_acc = 0.0
for batch, data in enumerate(loader_val, 1):
img, label = data[0].to(device), data[1].to(device)
pred = model(img)
acc = cal_acc(pred, label)
running_acc += acc.item()
acc_val = running_acc / batch
print(f"epoch: {epoch}, acc_val: {acc_val:.4f}")
history.append({"epoch": epoch, "loss_train": loss_train, "acc_train": acc_train, "acc_val": acc_val})
draw_loss_curve([history], legends=[f"resnet18-CIFAR{cifar}"], save_dir=f"history/resnet18-CIFAR{cifar}[{datetime.now()}]")
if __name__ == '__main__':
main()
解决方案
Resnet18 来自torchvision.models
它的 ImageNet 实现。因为 ImageNet 的样本比 CIFAR10/100 (32x32) 大得多 (224x224),所以第一层旨在积极地对输入进行下采样(“stem 网络”)。这会导致在小型 CIFAR10/100 图像上丢失许多有价值的信息。
为了在 CIFAR10 上达到良好的准确性,作者使用了不同的网络结构,如原始论文中所述: https ://arxiv.org/pdf/1512.03385.pdf 并在本文中解释: https ://towardsdatascience.com/resnets-for-cifar -10-e63e900524e0
您可以从此 repo 下载 resnet fo CIFAR10:https ://github.com/akamaster/pytorch_resnet_cifar10
推荐阅读
- node.js - 如何在 Telegraf 中处理长消息
- c# - C# Windows 窗体 LiveCharts GeoMap
- python - 如何在 Python 中从迭代器创建字符串?
- r - R 中的 MANOVA。如何解决此错误:“1L 中的错误:object$rank”
- python - 导入暗网给出错误:未定义名称“DARKNET_FORCE_CPU”
- excel - 如果时间在下午 4 点之后,则显示此单元格,如果错误则显示此单元格
- android - “bool 类型的无效参数 false。” 当颤振应用程序启动时
- reactjs - 如何获取异步/等待函数返回的值?
- c++ - 迭代参数包中的向量
- swift - 避免 Xcode 预览窗口中的 SSL 错误