python - 使用plt显示pytorch图片
问题描述
x
是图像,y
是标签,metadata
是日期、时间等。
for x, y_true, metadata in train_loader:
print(x.shape)
形状返回:
torch.Size([16, 3, 448, 448])
如何将 x 显示为图像?我用plt吗?
解决方案
您x
的不是单个图像,而是一组 16 个不同的图像,大小均为 448x448 像素。
您可以使用torchvision.utils.make_grid
转换x
为 4x4 图像的网格,然后绘制它:
import torchvision
with torch.no_grad(): # no need for gradients here
grid = torchvision.utils.make_grid(x, nrow=4) # you might consider normalize=True
# convert the grid into a numpy array suitable for plt
grid_np = grid.cpu().numpy().transpose(1, 2, 0) # channel dim should be last
plt.matshow(grid_np)
推荐阅读
- python-3.x - Geting list in python using map function
- c# - Dynamically formatting Text items in a ListBox in XAML
- sendgrid - 获取 Sendgrid 取消订阅首选项的 URL
- javascript - 如果其小部件不可见,如何隐藏 Blogger 部分
- php - 使用 POST 重定向数组
- postgresql - 让 Rails 或 Postges 进行排序会更快吗?
- google-cloud-platform - 谷歌云 OAuth 验证视频被拒绝,要求澄清,没有答案
- node.js - GraphQL 将输入参数传递给订阅
- regex - Google 表单 MMM.YYYY 中的正则表达式验证
- javascript - 如何在反应应用程序中存储用户特定数据