首页 > 解决方案 > 带有 CIFAR-100 的 pytorch

问题描述

import torch
import torchvision
import torchvision.transforms as transforms

transform=transforms.Compose(
[transforms.ToTensor(),
 transforms.Normalize((0.5),(0.5))])

trainset=torchvision.datasets.CIFAR100(root='./dataset',train=True,
                                  download=True,transform=transform)

trainloader=torch.utils.data.CIFAR100(trainset,batch_size=4,shuffle=True)

testset=torchvision.datasets.CIFAR100(root='./dataset',train=False,
                                 download=True,transform=transform)

testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False)

classes=('0','1','2','3','4','5','6','7','8','9')

如下模块“torch.utils.data”编写的错误消息没有属性“CIFAR100”

当我将 torch.utils.data 与 cifar-10 一起使用时,它可以工作,但它不能与 cifar-100 一起工作,你能告诉我为什么会这样吗?

标签: deep-learningpytorch

解决方案


您的 trainloader 行中有错误,您必须将 trainset 传递给torch.utils.data.DataLoader. 将此行替换为,

trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True)

推荐阅读