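python - torchvision.datasets.MNIST raises TypeError: expected np.ndarray (got numpy.ndarray)
Problem description
I am learning PyTorch and torchvision by working through examples, connecting remotely to a Jupyter notebook. When I try to run the program below, I keep hitting an error that I cannot get past.
torch.__version__: 1.7.1, torchvision.__version__: 0.8.2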
import torch
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
data_train = datasets.MNIST(root="./data/",
                            transform=transform,
                            train=True, download=True)
data_test = datasets.MNIST(root="./data/",
                           transform=transform,
                           train=False, download=True)
data_loader_train = torch.utils.data.DataLoader(
    dataset=data_train, batch_size=64, shuffle=True)
data_loader_test = torch.utils.data.DataLoader(
    dataset=data_test, batch_size=64, shuffle=True)

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print([labels[i] for i in range(64)])
plt.imshow(img)
But the output is:
runfile('C:/Users/ao/.spyder-py3/write1.py', wdir='C:/Users/ao/.spyder-py3')
Using downloaded and verified file: ./data/MNIST\raw\train-images-idx3-ubyte.gz
Extracting ./data/MNIST\raw\train-images-idx3-ubyte.gz to ./data/MNIST\raw
Using downloaded and verified file: ./data/MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./data/MNIST\raw\train-labels-idx1-ubyte.gz to ./data/MNIST\raw
Using downloaded and verified file: ./data/MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST\raw\t10k-images-idx3-ubyte.gz to ./data/MNIST\raw
Using downloaded and verified file: ./data/MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/MNIST\raw
Processing...
Traceback (most recent call last):
  File "C:\Users\ao\.spyder-py3\write1.py", line 20, in <module>
    train=True,download=True)
  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 79, in __init__
    self.download()
  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 152, in download
    read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 493, in read_image_file
    x = read_sn3_pascalvincent_tensor(f, strict=False)
  File "D:\ProgramData\Anaconda3\envs\apple\lib\site-packages\torchvision\datasets\mnist.py", line 480, in read_sn3_pascalvincent_tensor
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
TypeError: expected np.ndarray (got numpy.ndarray)
Solution
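The traceback ends inside torch.from_numpy, not in the MNIST-specific code, so one way to narrow the problem down is to call torch.from_numpy directly on a small array, outside torchvision. The snippet below is a minimal diagnostic sketch, not a confirmed fix; check_numpy_bridge is a hypothetical helper name introduced here for illustration. It reports which NumPy and PyTorch builds the interpreter actually loads and whether the conversion succeeds.

```python
import importlib


def check_numpy_bridge():
    """Report whether torch.from_numpy accepts a freshly created ndarray.

    Hypothetical helper for diagnosis only; it changes nothing in the
    environment.
    """
    report = {}
    try:
        np = importlib.import_module("numpy")
        torch = importlib.import_module("torch")
    except ImportError as exc:
        # One of the two packages is missing or fails to import at all.
        report["error"] = f"import failed: {exc}"
        return report

    # Record which builds are loaded; a stale or duplicate NumPy install
    # would show up as an unexpected version or file path here.
    report["numpy"] = np.__version__
    report["numpy_path"] = np.__file__
    report["torch"] = torch.__version__

    try:
        # Same call that fails at the bottom of the traceback.
        tensor = torch.from_numpy(np.zeros((2, 3), dtype=np.uint8))
        report["from_numpy_ok"] = True
        report["shape"] = tuple(tensor.shape)
    except TypeError as exc:
        report["from_numpy_ok"] = False
        report["error"] = str(exc)
    return report


report = check_numpy_bridge()
print(report)
```

If from_numpy_ok comes back False with the same TypeError, the dataset code is probably not at fault. This message is commonly reported when the NumPy installed in the environment does not match what the PyTorch build expects, so upgrading or reinstalling NumPy in the same conda environment and restarting the kernel is a reasonable thing to try. That is an assumption based on the error message, not something this post confirms.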