首页 > 解决方案 > torch.nn.fucntional.interpolate():参数设置

问题描述

我正在使用torch.nn.functional.interpolate()调整图像大小。

首先,我使用transforms.ToTensor()将图像转换为张量,其大小为 (3, 252, 252),(252, 252) 是导入图像的大小。我想要做的是用interpolate()函数创建一个大小为 (3, 504, 504) 的张量。

我设置了 para scale_factor=2,但它返回了 (3, 252, 504) 张量。然后我设置它scale_factor=(1,2,2)并收到这样的尺寸冲突错误:

size shape must match input shape. Input is 1D, size is 3

那么我应该如何设置参数以接收 (3, 504, 504) 张量?

标签: pythonimage-processingpytorchinterpolation

解决方案


如果您正在使用scale_factor,则需要提供一批图像而不是单个图像。所以你需要通过使用添加一批unsqueeze(0)然后赋予它interpolate如下功能:

import torch
import torch.nn.functional as F

img = torch.randn(3, 252, 252)  # torch.Size([3, 252, 252])
img = img.unsqueeze(0)  # torch.Size([1, 3, 252, 252])

out = F.interpolate(img, scale_factor=(2, 2), mode='nearest')
print(out.size()) # torch.Size([1, 3, 504, 504])

推荐阅读