torchvision.datasets.MNIST gives TypeError: expected np.ndarray (got numpy.ndarray)

Problem description

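I am learning PyTorch and torchvision by working through examples in a remotely connected Jupyter notebook. But when I try to run one program, I keep running into the same problem.
torch.__version__: 1.7.1, torchvision.__version__: 0.8.2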
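The versions above cover torch and torchvision but not NumPy. Because the failing call in the traceback converts a NumPy array into a tensor, reporting the installed NumPy version alongside the other two may help anyone diagnosing this. A minimal sketch, assuming Python 3.8+ (for `importlib.metadata`) and that the packages were installed with distribution metadata (as pip and conda normally do); `installed_versions` is a hypothetical helper name, and it reads package metadata without importing torch itself:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("numpy", "torch", "torchvision")):
    """Return {package name: version string, or None if not installed}.

    Reads installed distribution metadata only; the packages are not imported.
    """
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # no distribution with this name in the environment
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Run it with the same interpreter (the same conda environment) that produced the traceback, since a different environment may report different versions.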

import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# MNIST images are single-channel (grayscale), so Normalize takes one mean and one std value;
# a 3-value mean/std cannot be applied in place to a 1 x 28 x 28 tensor
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

data_train = datasets.MNIST(root="./data/", transform=transform,
                            train=True, download=True)
data_test = datasets.MNIST(root="./data/", transform=transform,
                           train=False, download=True)

data_loader_train = torch.utils.data.DataLoader(
    dataset=data_train, batch_size=64, shuffle=True)
data_loader_test = torch.utils.data.DataLoader(
    dataset=data_test, batch_size=64, shuffle=True)

# Take one batch and tile it into a single image grid;
# make_grid repeats a 1-channel batch into a 3-channel (C, H, W) tensor
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

# Convert to (H, W, C) for matplotlib and undo the normalization
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels.tolist())
plt.imshow(img)
plt.show()
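The std/mean step near the end of the script undoes `transforms.Normalize`, which maps each channel as `(x - mean) / std`. With mean = std = 0.5, the `[0, 1]` range produced by `ToTensor` becomes `[-1, 1]`, and `x * std + mean` brings it back for display. A minimal NumPy sketch of that arithmetic on an `(H, W, 3)` array like the transposed grid, where the 3-element statistics broadcast over the last (channel) axis; `normalize` and `denormalize` are hypothetical names illustrating the formula, not torchvision's implementation:

```python
import numpy as np

# Per-channel statistics, matching the 0.5 values used in the question
MEAN = np.array([0.5, 0.5, 0.5])
STD = np.array([0.5, 0.5, 0.5])


def normalize(img_hwc):
    """Apply (x - mean) / std per channel to an (H, W, 3) array."""
    return (img_hwc - MEAN) / STD


def denormalize(img_hwc):
    """Invert normalize(): x * std + mean, as done before plt.imshow."""
    return img_hwc * STD + MEAN


if __name__ == "__main__":
    pixel_values = np.array([[[0.0, 0.5, 1.0]]])  # shape (1, 1, 3): one pixel, three channels
    normed = normalize(pixel_values)
    print(normed)               # values -1.0, 0.0, 1.0
    print(denormalize(normed))  # back to 0.0, 0.5, 1.0
```

Because the grid is in HWC order after `transpose(1, 2, 0)`, plain Python lists or 1-D arrays of length 3 broadcast correctly without reshaping.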

But the output is:

runfile('C:/Users/ao/.spyder-py3/write1.py', wdir='C:/Users/ao/.spyder-py3')
Using downloaded and verified file: ./data/MNIST\raw\train-images-idx3-ubyte.gz
Extracting ./data/MNIST\raw\train-images-idx3-ubyte.gz to ./data/MNIST\raw
Using downloaded and verified file: ./data/MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./data/MNIST\raw\train-labels-idx1-ubyte.gz to ./data/MNIST\raw
Using downloaded and verified file: ./data/MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST\raw\t10k-images-idx3-ubyte.gz to ./data/MNIST\raw
Using downloaded and verified file: ./data/MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/MNIST\raw
Processing...
Traceback (most recent call last):

  File "C:\Users\ao\.spyder-py3\write1.py", line 20, in <module>
    train=True,download=True)

  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 79, in __init__
    self.download()

  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 152, in download
    read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),

  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 493, in read_image_file
    x = read_sn3_pascalvincent_tensor(f, strict=False)

  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 480, in read_sn3_pascalvincent_tensor
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

TypeError: expected np.ndarray (got numpy.ndarray)
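The last frame of the traceback is a `torch.from_numpy` call on the parsed IDX data, so the failure happens while converting a NumPy array, before any image is loaded. One way to narrow it down is to call `torch.from_numpy` directly on a tiny array outside torchvision: if that also fails, the problem lies in the torch/NumPy pairing of the environment rather than in MNIST's parser. This is a diagnostic sketch, not the accepted fix; `try_from_numpy` is a hypothetical helper, the array contents are arbitrary, and `uint8` is used only because the MNIST image files store unsigned bytes (`idx3-ubyte`):

```python
def try_from_numpy():
    """Try torch.from_numpy on a small uint8 array and describe the outcome."""
    try:
        import numpy as np
    except ImportError:
        return "numpy not installed"
    try:
        import torch
    except ImportError:
        return "torch not installed"

    arr = np.arange(6, dtype=np.uint8).reshape(2, 3)
    try:
        t = torch.from_numpy(arr)
    except (TypeError, RuntimeError) as exc:
        # The question's environment raised TypeError at this kind of call
        return f"from_numpy failed: {exc}"
    return f"ok: {tuple(t.shape)} {t.dtype}"


if __name__ == "__main__":
    print(try_from_numpy())
```

Running it in the same environment as the failing script keeps the comparison meaningful.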

Tags: python, pytorch, typeerror, mnist, torchvision

Solution

