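python - How do I generate accurate masks for an image from Mask R-CNN predictions in PyTorch?
Problem description
I have trained a Mask R-CNN network for instance segmentation of apples. I am able to load the weights and generate predictions for my test images. The generated masks seem to be in the right location, but the mask itself has no real shape; it just looks like a bunch of pixels.
Training was done on the dataset from this paper, and here is the GitHub link to the code used for training and generating the weights.
The prediction code is below. (I have omitted the parts that create the path variables and assign the paths.)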
import os
import glob
import numpy as np
import pandas as pd
import cv2 as cv
import fileinput
import torch
import torch.utils.data
import torchvision
from data.apple_dataset import AppleDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import utility.utils as utils
import utility.transforms as T
from PIL import Image
from matplotlib import pyplot as plt
%matplotlib inline
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

def get_maskrcnn_model_instance(num_classes):
    # build a Mask R-CNN model; pretrained=False, so no COCO-trained detection weights are loaded here
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the box predictor head with a new one sized for num_classes
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model

num_classes = 2
device = torch.device('cpu')
model = get_maskrcnn_model_instance(num_classes)
checkpoint = torch.load('model_49.pth', map_location=device)
model.load_state_dict(checkpoint['model'], strict=False)
dataset_test = AppleDataset(test_image_files_path, get_transform(train=False))
img, _ = dataset_test[1]
model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])
prediction
Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())
(Unable to show the image here because it is over 2 MB.)
Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())
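Here is the Imgur link to the original image. Below is the predicted mask for one of the instances.
Also, could you help me understand the data structure of the generated prediction shown below? How do I access the masks to generate a single image that shows all of them?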
[{'boxes': tensor([[ 966.8143, 1633.7491, 1106.7389, 1787.6367],
[1418.7872, 1467.0619, 1732.0828, 1796.1527],
[1608.0396, 2064.6482, 1710.7534, 2206.5535],
[2326.3750, 1690.3418, 2542.2112, 1883.2626],
[2213.2024, 1864.3657, 2299.8933, 1963.0178],
[1112.9083, 1732.5953, 1236.7600, 1823.0170],
[1150.8256, 614.0334, 1218.8584, 711.4094],
[ 942.7086, 794.6043, 1138.2318, 1008.0430],
[1065.4371, 723.0493, 1192.7570, 870.3763],
[1002.3103, 883.4616, 1146.9994, 1006.6841],
[1315.2816, 1680.8625, 1531.3210, 1989.3317],
[1244.5769, 1925.0903, 1459.5417, 2175.3252],
[1725.2191, 2082.6187, 1934.0227, 2274.2952],
[ 936.3065, 1554.3765, 1014.2722, 1659.4229],
[ 934.8851, 1541.3331, 1090.4736, 1657.3751],
[2486.0120, 776.4577, 2547.2329, 847.9725],
[2336.1675, 698.6327, 2508.6492, 921.4550],
[2368.4077, 1954.1102, 2448.4004, 2049.5796],
[1899.1403, 1775.2371, 2035.7561, 1962.6923],
[2176.0664, 1075.1553, 2398.6084, 1267.2555],
[2274.8899, 641.6769, 2395.9634, 791.3353],
[2535.1580, 874.4780, 2642.8213, 966.4614],
[2183.4236, 619.9688, 2288.5676, 758.6825],
[2183.9832, 1122.9382, 2334.9583, 1263.3226],
[1135.7822, 779.0529, 1225.9871, 890.0135],
[ 317.3954, 1328.6995, 397.3900, 1467.7740],
[ 945.4811, 1833.3708, 997.2318, 1878.8607],
[1992.4447, 679.4969, 2134.6667, 835.8701],
[1098.5416, 1452.7799, 1429.1808, 1771.4460],
[1657.3193, 1405.5405, 1781.6273, 1574.6780],
[1443.8911, 1747.1544, 1739.0361, 2076.9724],
[1092.6003, 1165.3340, 1206.0881, 1383.8314],
[2466.4170, 1945.5931, 2555.1931, 2039.8368],
[2561.8508, 1616.2659, 2672.1033, 1742.2332],
[1894.4806, 907.9214, 2097.1875, 1182.6473],
[2321.5005, 1701.3344, 2368.3699, 1865.3914],
[2180.0781, 567.5969, 2344.6357, 763.4360],
[1845.7612, 668.6808, 2045.2688, 899.8501],
[1858.9216, 2145.7097, 1961.8870, 2273.5088],
[ 261.4607, 1314.0154, 396.9288, 1486.9498],
[2488.1682, 1585.2357, 2669.0178, 1794.9926],
[2696.9548, 936.0087, 2802.7961, 1025.2294],
[1593.6837, 1489.8641, 1720.3124, 1627.8135],
[2517.9468, 857.1713, 2567.1125, 929.4335],
[1943.2167, 636.3422, 2151.4419, 853.8924],
[2143.5664, 1100.0521, 2308.1570, 1290.7125],
[2140.9231, 1947.9692, 2238.6956, 2000.6249],
[1461.6316, 2105.2593, 1559.7675, 2189.0264],
[2114.0781, 374.8153, 2222.8838, 559.9851],
[2350.5320, 726.5779, 2466.8140, 878.2617]]),
'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1]),
'scores': tensor([0.9916, 0.9841, 0.9669, 0.9337, 0.9118, 0.7729, 0.7202, 0.7193, 0.6928,
0.6872, 0.6690, 0.5913, 0.4877, 0.4683, 0.3781, 0.3327, 0.3164, 0.2364,
0.1696, 0.1692, 0.1502, 0.1365, 0.1316, 0.1171, 0.1119, 0.1094, 0.1041,
0.0865, 0.0853, 0.0835, 0.0822, 0.0816, 0.0797, 0.0796, 0.0788, 0.0780,
0.0757, 0.0736, 0.0736, 0.0689, 0.0681, 0.0644, 0.0642, 0.0630, 0.0612,
0.0598, 0.0563, 0.0531, 0.0525, 0.0522]),
'masks': tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
...,
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]])}]
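Solution
Mask R-CNN predictions have the following structure:
During inference, the model requires only the input tensors and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows: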
boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between 0 and H and 0 and W
labels (Int64Tensor[N]): the predicted labels for each image
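scores (Tensor[N]): the scores of each prediction
masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range.
Note that, as the printed output in the question shows, the masks come back as floating-point soft masks (per-pixel probabilities). The torchvision documentation adds that the final segmentation masks are obtained by thresholding the soft masks, generally with a value of 0.5 (mask >= 0.5).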
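To make the structure concrete, here is a minimal sketch of how these fields could be combined into a single image that shows all masks. It assumes the tensors have already been converted to NumPy arrays (for example with `prediction[0]['masks'].cpu().numpy()`). The function name `combine_masks` and the two 0.5 thresholds are illustrative choices, not part of torchvision.

```python
import numpy as np


def combine_masks(masks, scores, score_thresh=0.5, mask_thresh=0.5):
    """Merge per-instance soft masks into one uint8 image (values 0 or 255).

    masks:  float array of shape [N, 1, H, W] with values in [0, 1]
    scores: float array of shape [N]
    (hypothetical helper; both thresholds are assumptions to tune)
    """
    masks = np.asarray(masks, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)

    # Keep only detections whose confidence reaches the score threshold.
    keep = scores >= score_thresh

    # Drop the singleton channel axis and binarize each kept soft mask.
    binary = masks[keep, 0] >= mask_thresh  # shape [K, H, W], dtype bool

    # A pixel is foreground if any kept instance covers it.
    combined = binary.any(axis=0)  # shape [H, W]
    return combined.astype(np.uint8) * 255


if __name__ == "__main__":
    # Tiny synthetic stand-in for a real prediction: three 4x4 instances.
    demo_masks = np.zeros((3, 1, 4, 4), dtype=np.float32)
    demo_masks[0, 0, :2, :2] = 0.9   # confident instance, top-left
    demo_masks[1, 0, 2:, 2:] = 0.8   # confident instance, bottom-right
    demo_masks[2, 0, :, :] = 0.9     # low-score instance, filtered out
    demo_scores = np.array([0.99, 0.72, 0.05])
    print(combine_masks(demo_masks, demo_scores))
```

With the real output, this would be called as `combine_masks(prediction[0]['masks'].cpu().numpy(), prediction[0]['scores'].cpu().numpy())`, and the result can be passed straight to `Image.fromarray`. Filtering by score matters here because many of the 50 detections in the question have scores below 0.1.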
You can use OpenCV's findContours and drawContours functions to draw the masks, as shown below:
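import cv2

# cv2.imread takes an IMREAD_* flag, not a color-conversion code; the default loads BGR
img_cv = cv2.imread('input.jpg')
for i in range(len(prediction[0]['masks'])):
    # soft mask of instance i, shape [H, W], values in [0, 1]
    mask = prediction[0]['masks'][i, 0]
    # threshold at 0.5, then scale to a 0/255 uint8 image for OpenCV
    mask = (mask > 0.5).byte().mul(255).cpu().numpy()
    contours, _ = cv2.findContours(
        mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    # colors are BGR in OpenCV, so (255, 0, 0) draws blue outlines
    cv2.drawContours(img_cv, contours, -1, (255, 0, 0), 2, cv2.LINE_AA)
cv2.imshow('img output', img_cv)
cv2.waitKey(0)  # keep the window open until a key is pressed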
Sample output: (image not reproduced here)
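If filled regions are preferred over outlines, the same prediction fields can also be blended onto the image with plain NumPy. The following is only a sketch under stated assumptions: the image is a uint8 RGB array of shape [H, W, 3], the masks and scores are NumPy arrays as returned by `.cpu().numpy()`, and the helper name `overlay_instances`, the thresholds, and the random colors are all illustrative.

```python
import numpy as np


def overlay_instances(image, masks, scores, score_thresh=0.5,
                      mask_thresh=0.5, alpha=0.5, seed=0):
    """Blend one random color per kept instance onto an RGB image.

    image:  uint8 array of shape [H, W, 3]
    masks:  float array of shape [N, 1, H, W] with values in [0, 1]
    scores: float array of shape [N]
    (hypothetical helper; thresholds and alpha are assumptions to tune)
    """
    rng = np.random.default_rng(seed)
    # Work on a float copy so the caller's image is left untouched.
    out = np.asarray(image, dtype=np.float32).copy()
    masks = np.asarray(masks)
    scores = np.asarray(scores)

    for mask, score in zip(masks, scores):
        if score < score_thresh:
            continue  # skip low-confidence detections
        region = mask[0] >= mask_thresh  # [H, W] bool for this instance
        color = rng.integers(0, 256, size=3).astype(np.float32)
        # Alpha-blend the instance color over the pixels it covers.
        out[region] = (1.0 - alpha) * out[region] + alpha * color

    return np.clip(out, 0, 255).astype(np.uint8)
```

For the question's variables, the image array could come from `img.mul(255).permute(1, 2, 0).byte().numpy()`, and the returned array can be displayed with `Image.fromarray`.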