首页 > 解决方案 > PyTorch 在使用 Dataloader 加载时平铺图像

问题描述

我正在尝试使用 PyTorch 数据加载器加载图像数据集,但是生成的转换是平铺的,并且没有像我期望的那样将原始图像裁剪到中心。

transform = transforms.Compose([transforms.Resize(224),
                             transforms.CenterCrop(224),
                             transforms.ToTensor()])

dataset = datasets.ImageFolder('ml-models/downloads/', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)


images, labels = next(iter(dataloader))
import matplotlib.pyplot as plt
plt.imshow(images[6].reshape(224, 224, 3))

生成的图像是平铺的,而不是中心裁剪的。[![如此处的 Jupyter 快照所示][1]][1]

提供的转换有问题吗?(链接如下图所示:)[1]:https ://i.stack.imgur.com/HtrIa.png

标签: image-processingpytorchtorchvision

解决方案


Pytorch 以通道优先格式存储张量,因此 3 通道图像是形状为 (3, H, W) 的张量。Matplotlib 期望数据采用通道最后格式,即 (H, W, 3)。重塑不会重新排列尺寸,因为您需要Tensor.permute

plt.imshow(images[6].permute(1, 2, 0))

推荐阅读