How to generate accurate masks for an image from Mask R-CNN predictions in PyTorch?

Problem description

I have trained a Mask R-CNN network for instance segmentation of apples. I am able to load the weights and generate predictions for my test images. The generated masks appear to be in the right locations, but the masks themselves have no real shape; they just look like a jumble of pixels.

Training was done on the dataset from this paper, and here is the GitHub link to the code used for training and generating the weights.

The prediction code is below. (I have omitted the part that creates the path variables and assigns the paths.)

import os
import glob
import numpy as np
import pandas as pd
import cv2 as cv
import fileinput

import torch
import torch.utils.data
import torchvision

from data.apple_dataset import AppleDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import utility.utils as utils
import utility.transforms as T

from PIL import Image
from matplotlib import pyplot as plt
%matplotlib inline


def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

def get_maskrcnn_model_instance(num_classes):
    # load an instance segmentation model (COCO pre-training is not needed
    # here, since trained weights are loaded from a checkpoint below)
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model

num_classes = 2
device = torch.device('cpu')

model = get_maskrcnn_model_instance(num_classes)
checkpoint = torch.load('model_49.pth', map_location=device)
model.load_state_dict(checkpoint['model'], strict=False)

dataset_test = AppleDataset(test_image_files_path, get_transform(train=False))
img, _ = dataset_test[1]
model.eval()

with torch.no_grad():
    prediction = model([img.to(device)])

prediction

Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

(Unable to upload the image here since it's over 2 MB.)

Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())

Here is the Imgur link to the original image. Below is the predicted mask for one of the instances.

[mask output for one instance]

Also, could you help me understand the data structure of the generated prediction shown below? How can I access the masks so as to generate a single image that shows all of the masks?

[{'boxes': tensor([[ 966.8143, 1633.7491, 1106.7389, 1787.6367],
          [1418.7872, 1467.0619, 1732.0828, 1796.1527],
          [1608.0396, 2064.6482, 1710.7534, 2206.5535],
          [2326.3750, 1690.3418, 2542.2112, 1883.2626],
          [2213.2024, 1864.3657, 2299.8933, 1963.0178],
          [1112.9083, 1732.5953, 1236.7600, 1823.0170],
          [1150.8256,  614.0334, 1218.8584,  711.4094],
          [ 942.7086,  794.6043, 1138.2318, 1008.0430],
          [1065.4371,  723.0493, 1192.7570,  870.3763],
          [1002.3103,  883.4616, 1146.9994, 1006.6841],
          [1315.2816, 1680.8625, 1531.3210, 1989.3317],
          [1244.5769, 1925.0903, 1459.5417, 2175.3252],
          [1725.2191, 2082.6187, 1934.0227, 2274.2952],
          [ 936.3065, 1554.3765, 1014.2722, 1659.4229],
          [ 934.8851, 1541.3331, 1090.4736, 1657.3751],
          [2486.0120,  776.4577, 2547.2329,  847.9725],
          [2336.1675,  698.6327, 2508.6492,  921.4550],
          [2368.4077, 1954.1102, 2448.4004, 2049.5796],
          [1899.1403, 1775.2371, 2035.7561, 1962.6923],
          [2176.0664, 1075.1553, 2398.6084, 1267.2555],
          [2274.8899,  641.6769, 2395.9634,  791.3353],
          [2535.1580,  874.4780, 2642.8213,  966.4614],
          [2183.4236,  619.9688, 2288.5676,  758.6825],
          [2183.9832, 1122.9382, 2334.9583, 1263.3226],
          [1135.7822,  779.0529, 1225.9871,  890.0135],
          [ 317.3954, 1328.6995,  397.3900, 1467.7740],
          [ 945.4811, 1833.3708,  997.2318, 1878.8607],
          [1992.4447,  679.4969, 2134.6667,  835.8701],
          [1098.5416, 1452.7799, 1429.1808, 1771.4460],
          [1657.3193, 1405.5405, 1781.6273, 1574.6780],
          [1443.8911, 1747.1544, 1739.0361, 2076.9724],
          [1092.6003, 1165.3340, 1206.0881, 1383.8314],
          [2466.4170, 1945.5931, 2555.1931, 2039.8368],
          [2561.8508, 1616.2659, 2672.1033, 1742.2332],
          [1894.4806,  907.9214, 2097.1875, 1182.6473],
          [2321.5005, 1701.3344, 2368.3699, 1865.3914],
          [2180.0781,  567.5969, 2344.6357,  763.4360],
          [1845.7612,  668.6808, 2045.2688,  899.8501],
          [1858.9216, 2145.7097, 1961.8870, 2273.5088],
          [ 261.4607, 1314.0154,  396.9288, 1486.9498],
          [2488.1682, 1585.2357, 2669.0178, 1794.9926],
          [2696.9548,  936.0087, 2802.7961, 1025.2294],
          [1593.6837, 1489.8641, 1720.3124, 1627.8135],
          [2517.9468,  857.1713, 2567.1125,  929.4335],
          [1943.2167,  636.3422, 2151.4419,  853.8924],
          [2143.5664, 1100.0521, 2308.1570, 1290.7125],
          [2140.9231, 1947.9692, 2238.6956, 2000.6249],
          [1461.6316, 2105.2593, 1559.7675, 2189.0264],
          [2114.0781,  374.8153, 2222.8838,  559.9851],
          [2350.5320,  726.5779, 2466.8140,  878.2617]]),
  'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1]),
  'scores': tensor([0.9916, 0.9841, 0.9669, 0.9337, 0.9118, 0.7729, 0.7202, 0.7193, 0.6928,
          0.6872, 0.6690, 0.5913, 0.4877, 0.4683, 0.3781, 0.3327, 0.3164, 0.2364,
          0.1696, 0.1692, 0.1502, 0.1365, 0.1316, 0.1171, 0.1119, 0.1094, 0.1041,
          0.0865, 0.0853, 0.0835, 0.0822, 0.0816, 0.0797, 0.0796, 0.0788, 0.0780,
          0.0757, 0.0736, 0.0736, 0.0689, 0.0681, 0.0644, 0.0642, 0.0630, 0.0612,
          0.0598, 0.0563, 0.0531, 0.0525, 0.0522]),
  'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          ...,


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]]])}]

Tags: python, machine-learning, computer-vision, pytorch, torchvision

Solution


A Mask R-CNN prediction has the following structure:

During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of each Dict are as follows:

boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x between 0 and W and values of y between 0 and H
labels (Int64Tensor[N]): the predicted labels for each image
scores (Tensor[N]): the scores of each prediction
masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in the 0-1 range. To obtain the final segmentation masks, the soft masks can be thresholded, generally with a value of 0.5 (mask >= 0.5).
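
For example, here is a minimal sketch (reusing the prediction variable from your code; the 0.9 score threshold is an arbitrary choice for illustration) of how to index these fields and keep only the confident detections:

pred = prediction[0]              # Dict for the first (and only) input image
keep = pred['scores'] > 0.9       # boolean mask over the N detections

boxes  = pred['boxes'][keep]      # (M, 4) boxes in [x1, y1, x2, y2]
labels = pred['labels'][keep]     # (M,) class labels (here, 1 = apple)
masks  = pred['masks'][keep]      # (M, 1, H, W) soft masks in the 0-1 range

print('kept', int(keep.sum()), 'of', len(pred['scores']), 'detections')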

You can use OpenCV's findContours and drawContours functions to draw the masks, as shown below:

import cv2

# imread's second argument is a read flag, not a colour-conversion code,
# so cv2.COLOR_BGR2RGB was dropped; OpenCV loads the image in BGR order
img_cv = cv2.imread('input.jpg')

for i in range(len(prediction[0]['masks'])):
    # iterate over the predicted instance masks
    mask = prediction[0]['masks'][i, 0]
    mask = mask.mul(255).byte().cpu().numpy()
    contours, _ = cv2.findContours(
            mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    cv2.drawContours(img_cv, contours, -1, (255, 0, 0), 2, cv2.LINE_AA)

cv2.imshow('img output', img_cv)
cv2.waitKey(0)

Sample output:

[sample output image]
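
As for the second part of your question: since the masks are soft probability maps, one way to build a single image showing all the masks is to threshold each mask (0.5 is the usual choice, per the documentation quoted above) and take their union. A sketch, again assuming the prediction variable from above:

import numpy as np
from PIL import Image

masks = prediction[0]['masks']              # (N, 1, H, W), values in [0, 1]
binary = (masks > 0.5).squeeze(1).numpy()   # (N, H, W) boolean instance masks
combined = np.any(binary, axis=0)           # union of all instances -> (H, W)

Image.fromarray((combined * 255).astype(np.uint8)).save('all_masks.png')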

