python - Matplotlib Plot 将图像显示为单独的 rgb 图像而不是单个 RGB 图像
问题描述
我有一个表示为形状 [224, 224, 3] 的 Numpy 数组的图像。我正在尝试使用 matplotlib 绘制这个: plt.imshow(img) 但它不是获取单个 RGB 图像,而是在单个图中绘制分离的 R、G 和 B 图像。我哪里错了?
我尝试查看图像的形状以及绘制图像的一些示例。“img”变量的形状为 [224, 224, 3],是一个 numpy 数组类型。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Define Image Transform
transform = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor()])
# Load Custom Image Dataset
dataset = datasets.ImageFolder(root="./Cat_Dog_data",
transform=transform)
# DataLoader
dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
# Get one batch of Data
# len(images): 32
# len(labels): 32
# shape of images[0]: torch.Size([3, 224, 224])
images, labels = next(iter(dataLoader))
# img.shape: [224,224,3]
img = images[0].numpy().reshape([224, 224, 3])
plt.imshow(img)
plt.show()
解决方案
该代码最终使用 np.transpose() 函数而不是 np.reshape() 函数工作。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Define Image Transform
transform = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor()])
# Load Custom Image Dataset
dataset = datasets.ImageFolder(root="./Cat_Dog_data", transform=transform)
# DataLoader is a Generator
dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
# Get one batch of Data
images, labels = next(iter(dataLoader))
# Use transpose instead of reshape.
img = images[0].numpy().transpose((1, 2, 0))
plt.imshow(img)
plt.show()