首页 > 解决方案 > 使用plt显示pytorch图片

问题描述

x是图像,y是标签,metadata是日期、时间等。

for x, y_true, metadata in train_loader:
  print(x.shape)

形状返回:

torch.Size([16, 3, 448, 448])

如何将 x 显示为图像?我用plt吗?

标签: pythonmatplotlibpytorch

解决方案


x的不是单个图像,而是一组 16 个不同的图像,大小均为 448x448 像素。

您可以使用torchvision.utils.make_grid转换x为 4x4 图像的网格,然后绘制它:

import torchvision

with torch.no_grad():  # no need for gradients here
  grid = torchvision.utils.make_grid(x, nrow=4)  # you might consider normalize=True 
  # convert the grid into a numpy array suitable for plt
  grid_np = grid.cpu().numpy().transpose(1, 2, 0)  # channel dim should be last
  plt.matshow(grid_np)

推荐阅读