ResNet18 in PyTorch achieves low accuracy on CIFAR100

Problem description

I am training a ResNet18 on the CIFAR100 dataset. After about 50 epochs the validation accuracy converges at around 34%, while the training accuracy reaches almost 100%.

I suspected it was overfitting somewhat, so I applied data augmentation such as RandomHorizontalFlip and RandomRotation, which brought the validation accuracy up to around 40%.

I also tried decaying the learning rate over [0.1, 0.03, 0.01, 0.003, 0.001], stepping down every 50 epochs. Decaying the learning rate did not seem to improve performance.

I have heard that ResNets can reach 70%~80% accuracy on CIFAR100. What other tricks can I apply? Or is something wrong with my implementation? The same code reaches around 80% accuracy on CIFAR10.

My full training and evaluation code is below:

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.transforms import Compose, ToTensor, RandomHorizontalFlip, RandomRotation, Normalize
from torchvision.datasets import CIFAR10, CIFAR100
import os
from datetime import datetime
import matplotlib.pyplot as plt


def draw_loss_curve(histories, legends, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    for key in histories[0][0].keys():
        if key != "epoch":
            plt.figure()
            plt.title(key)
            for history in histories:
                x = [h["epoch"] for h in history]
                y = [h[key] for h in history]
                # plt.ylim(ymin=0, ymax=3.0)
                plt.plot(x, y)
            plt.legend(legends)
            plt.savefig(os.path.join(save_dir, key + ".png"))


def cal_acc(out, label):
    # Top-1 accuracy: fraction of the batch whose argmax prediction matches the label.
    batch_size = label.shape[0]
    pred = torch.argmax(out, dim=1)
    num_true = torch.nonzero(pred == label).shape[0]
    acc = num_true / batch_size
    return torch.tensor(acc)


class LrManager(optim.lr_scheduler.LambdaLR):
    # Piecewise-constant schedule: `lrs` maps a starting epoch to the multiplier
    # applied to the optimizer's base LR from that epoch onward.
    def __init__(self, optimizer, lrs):
        def f(epoch):
            # Use the multiplier of the largest milestone <= the current epoch.
            rate = 1
            for k in sorted(lrs.keys()):
                if epoch >= k:
                    rate = lrs[k]
                else:
                    break
            return rate
        super(LrManager, self).__init__(optimizer, f)


def main(cifar=100, epochs=250, batches_show=100):
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        print("warning: CUDA is not available, using CPU instead")
    dataset_cls = CIFAR10 if cifar == 10 else CIFAR100
    dataset_train = dataset_cls(root=f"data/{dataset_cls.__name__}/", download=True, train=True,
                                transform=Compose([RandomHorizontalFlip(), RandomRotation(15), ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
    dataset_val = dataset_cls(root=f"data/{dataset_cls.__name__}/", download=True, train=False,
                              transform=Compose([ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
    loader_train = DataLoader(dataset_train, batch_size=128, shuffle=True)
    loader_val = DataLoader(dataset_val, batch_size=128, shuffle=True)
    model = resnet18(pretrained=False, num_classes=cifar).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
    lr_scheduler = LrManager(optimizer, {0: 1.0, 50: 0.3, 100: 0.1, 150: 0.03, 200: 0.01})
    criterion = nn.CrossEntropyLoss()

    history = []
    for epoch in range(epochs):
        model.train()  # (re)enable train mode each epoch; evaluation below switches to eval mode
        print("-------------------  TRAINING  -------------------")
        loss_train = 0.0
        running_loss = 0.0
        acc_train = 0.0
        running_acc = 0.0
        for batch, data in enumerate(loader_train, 1):
            img, label = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            pred = model(img)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            loss_train += loss.item()
            acc = cal_acc(pred, label)
            running_acc += acc.item()
            acc_train += acc.item()

            if batch % batches_show == 0:
                print(f"epoch: {epoch}, batch: {batch}, loss: {running_loss/batches_show:.4f}, acc: {running_acc/batches_show:.4f}")
                running_loss = 0.0
                running_acc = 0.0
        loss_train = loss_train / batch
        acc_train = acc_train / batch
        lr_scheduler.step()

        print("------------------- EVALUATING -------------------")
        with torch.no_grad():
            running_acc = 0.0
            for batch, data in enumerate(loader_val, 1):
                img, label = data[0].to(device), data[1].to(device)
                pred = model(img)
                acc = cal_acc(pred, label)
                running_acc += acc.item()
            acc_val = running_acc / batch
            print(f"epoch: {epoch}, acc_val: {acc_val:.4f}")

        history.append({"epoch": epoch, "loss_train": loss_train, "acc_train": acc_train, "acc_val": acc_val})
    draw_loss_curve([history], legends=[f"resnet18-CIFAR{cifar}"], save_dir=f"history/resnet18-CIFAR{cifar}[{datetime.now()}]")


if __name__ == '__main__':
    main()

Tags: python, pytorch

Solution


The resnet18 from torchvision.models is the ImageNet implementation. Because ImageNet samples are much larger (224x224) than CIFAR10/100 samples (32x32), the first layers of the network (its "stem") are designed to downsample the input aggressively. This throws away a lot of valuable information on the small CIFAR10/100 images.
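For illustration only (this sketch is not part of the original answer): one common adaptation is to shrink the stem of the torchvision model so it no longer downsamples the 32x32 input. The 3x3/stride-1 conv and the removed max-pool below are assumptions based on typical CIFAR variants of ResNet, not values given by the answer.

import torch
from torch import nn
from torchvision.models import resnet18

# Sketch: adapt the ImageNet stem for 32x32 CIFAR inputs.
# The stock stem (7x7 conv, stride 2, then a stride-2 max-pool) shrinks a
# 32x32 image to 8x8 before the first residual block.
model = resnet18(num_classes=100)  # random init, as in the question
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()  # drop the stride-2 pooling entirely

x = torch.randn(2, 3, 32, 32)
print(model(x).shape)  # torch.Size([2, 100])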

To reach good accuracy on CIFAR10, the authors used a different network structure, as described in the original paper (https://arxiv.org/pdf/1512.03385.pdf) and explained in this article: https://towardsdatascience.com/resnets-for-cifar-10-e63e900524e0

You can download a ResNet for CIFAR10 from this repo: https://github.com/akamaster/pytorch_resnet_cifar10
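A usage sketch, assuming you have cloned that repo and its resnet.py is importable (the factory name resnet20 is taken from the repo; check the file for the exact signatures and the number of output classes):

from resnet import resnet20  # resnet.py from the repo above

# CIFAR-style ResNet-20: 3x3 stem, no max-pool, 16/32/64-channel stages
model = resnet20()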

