python - Pytorch 中 Faster-RCNN 模型的输入图像大小
问题描述
我正在尝试用 Pytorch 实现 Faster-RCNN 模型。在结构中,模型的第一个元素是变换。
from torchvision.models.detection import fasterrcnn_resnet50_fpn
model = fasterrcnn_resnet50_fpn(pretrained=True)
print(model.transform)
GeneralizedRCNNTransform(
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
当图像通过 时Resize()
,它们以(800,h)
或(w, 1333)
根据宽度和高度的比例出现。
for i in range(2):
_, image, target = testset.__getitem__(i)
img = image.unsqueeze(0)
output, _ = model.transform(img)
Before Transform : torch.Size([512, 640])
After Transform : [(800, 1000)]
Before Transform : torch.Size([315, 640])
After Transform : [(656, 1333)]
我的问题是如何获得那些调整大小的输出以及他们为什么使用这种方法?我在论文中找不到信息,也无法理解关于transform in 的源代码fasterrcnn_resnet50_fpn
。
对不起我的英语不好
解决方案
对输入执行数据转换以输入模型
min_size:在将其馈送到主干之前要重新缩放的图像的最小尺寸。
max_size:在将其馈送到主干之前要重新缩放的图像的最大尺寸
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py#L256
我也找不到为什么它被概括为最小 800 和最大 1333,也没有在研究论文中找到任何东西。
但由于第一层是 Conv 层,网络的输入是固定大小的,我应用了许多其他增强功能,例如镜像、随机裁剪等,灵感来自基于 SSD 的网络。因此,我宁愿在一个单独的地方进行一次所有扩充,而不是两次。我认为该模型应该在使用形状和其他属性尽可能接近训练数据的图像进行验证期间发挥最佳效果。
尽管您可以尝试自定义 min_size 和 max_size ...
`
from .transform import GeneralizedRCNNTransform
min_size = 900 #changed from default
max_size = 1433 #changed from default
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]
model = fasterrcnn_resnet50_fpn(pretrained=True, min_size, max_size, image_mean, image_std)
#batch of 4 image, 4 bboxes
images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
labels = torch.randint(1, 91, (4, 11))
images = list(image for image in images)
targets = []
for i in range(len(images)):
d = {}
d['boxes'] = boxes[i]
d['labels'] = labels[i]
targets.append(d)
output = model(images, targets)
`
或者您可以完全编写您的转换 https://pytorch.org/vision/stable/transforms.html
'
from torchvision.transforms import transforms as T
model = fasterrcnn_resnet50_rpn()
model.transform = T.Compose([*check torchvision.transforms for more*])
'
希望这可以帮助。
推荐阅读
- matplotlib - pcolormesh'GeoAxesSubplot'对象的cartopy问题没有属性'_hold'
- javascript - 从标签顺序中删除 HTML 标记
- duplicates - 显示重复的整行
- javascript - 如何在表头添加按钮?(反应原生)
- angular - Angular Uncaught TypeError: e is not a constructor after build --prod (work on ng serve)
- php - Google Speech to Text 与 Asterisk 实时通话的集成
- sql - SQL Server 浮点数
- mysql - 如何删除由同一表中的数据指定的 mySQL 表中的行(在一个表达式中)?
- optimization - 差流依赖和反依赖
- python-3.x - 从虚拟特征中提取分类数据