首页 > 解决方案 > PyTorch:FasterRCNN/MaskRCNN 在不同 cuda 设备上的不同输出

问题描述

我正在尝试在 PyTorch 中进行对象检测的预训练 fast-rcnn 模型,并观察到在不同 cuda 设备上执行以下代码时出现奇怪的行为。

import io
import torch
import torchvision.transforms as transforms
from torchvision.models.detection.faster_rcnn import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from PIL import Image
from torch.autograd import Variable

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
backbone = resnet_fpn_backbone('resnet50', True)
model = FasterRCNN(backbone, num_classes=91)
state_dict = torch.load('/home/ubuntu/state_dicts/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth')
model.load_state_dict(state_dict)
model.to(device)
model.eval()


def pre_process(image_bytes):
    my_preprocess = transforms.Compose([transforms.ToTensor()])
    image = Image.open(io.BytesIO(image_bytes))
    image = my_preprocess(image)
    return image


def get_prediction(image_bytes, threshold=0.5):
    tensor = pre_process(image_bytes=image_bytes)
    tensor = Variable(tensor).to(device)
    pred = model([tensor])
    print(pred)


with open("/home/ubuntu/persons.jpg", 'rb') as f:
    image_bytes = f.read()
    get_prediction(image_bytes=image_bytes)

上面的代码为设备 cuda:0 和 cuda:1 返回不同的输出。有人可以帮助理解同一台机器上不同 cuda 设备上相同型号的输出差异。

在使用 cuda:0 时,它返回与在 CPU 上相同的响应。

在使用 MaskRCNN 模型时观察到相同的行为。

cuda 上的输出:0

[{'boxes': tensor([[167.4223,  57.0383, 301.3054, 436.6868],
        [ 89.6149,  64.8980, 191.4021, 446.6606],
        [362.3454, 161.9877, 515.5366, 385.2343],
        [ 67.3742, 277.6379, 111.6810, 400.2647],
        [228.7159, 145.8775, 303.5066, 231.1051],
        [379.4247, 259.9776, 419.0149, 317.9510],
        [517.9014, 149.5500, 636.5953, 365.5251],
        [268.9992, 217.2433, 423.9517, 390.4785],
        [539.6832, 157.8171, 616.1689, 253.0961],
        [477.1378, 147.9255, 611.0255, 297.9276],
        [286.6689, 216.3575, 550.4538, 383.1956],
        [627.4468, 177.1990, 640.0000, 247.3514],
        [ 88.3993, 226.4796, 560.9189, 421.6618],
        [406.9602, 261.8285, 453.7620, 357.5365],
        [451.3659, 207.4905, 504.6570, 287.6619],
        [454.3897, 207.9612, 487.7692, 270.3133],
        [451.8828, 208.3855, 631.0622, 355.3239],
        [497.1180, 289.9157, 581.5941, 356.1050],
        [600.6650, 183.4176, 621.5589, 250.3380],
        [559.7050, 202.6747, 608.1462, 250.1502],
        [375.3307, 245.6641, 444.8958, 333.0625],
        [453.1024, 210.8463, 553.8406, 296.7747],
        [555.2745, 199.9524, 611.2347, 250.5636],
        [359.7946, 219.5903, 425.5572, 316.5619],
        [476.7842, 249.0592, 583.8101, 354.6469],
        [ 71.4854, 333.2897, 108.0255, 399.1010],
        [207.6522, 121.4260, 301.1808, 251.5350],
        [550.4424, 175.4845, 621.4010, 317.4897],
        [445.1313, 209.7148, 519.7682, 331.3234],
        [523.6974, 193.5186, 548.5457, 234.6627],
        [449.0608, 229.3627, 572.3047, 293.8238],
        [348.8312, 185.0679, 620.9442, 368.1201],
        [578.4594, 232.6871, 586.2761, 246.6013],
        [359.9344, 166.1812, 502.6697, 287.2637],
        [ 43.1700, 244.8350, 407.5768, 394.7983],
        [115.0793, 126.5799, 177.2827, 198.4358],
        [476.8102, 147.0127, 566.3655, 260.0383],
        [410.9664, 258.0466, 514.5250, 357.0403],
        [450.8164, 277.2901, 521.0891, 359.8105],
        [ 63.9356, 221.3673, 126.4192, 409.7991],
        [625.5704, 189.2636, 640.0000, 256.4739],
        [  1.7555, 174.2491,  86.2912, 436.6681],
        [ 65.3964, 274.4007, 106.8389, 349.2521],
        [558.3841, 197.9385, 639.8632, 368.0412],
        [193.0894, 164.9078, 599.5771, 384.6865],
        [269.0641, 126.7004, 324.2201, 146.3630],
        [359.1832, 201.2081, 484.3798, 276.5368],
        [580.0465, 231.4633, 593.2866, 247.9024],
        [454.5699, 142.0131, 634.2507, 258.4456],
        [616.1375, 246.1040, 639.7282, 255.8053],
        [309.7035, 151.7276, 518.3733, 249.3150],
        [615.1505, 246.0356, 639.2537, 255.4936],
        [452.0419, 199.0634, 584.8884, 357.6918],
        [270.1078, 216.1271, 408.6000, 395.1962],
        [564.9176, 199.7667, 606.9827, 245.9028],
        [  1.7000, 279.6961,  92.9089, 393.7010],
        [495.4763, 253.3147, 640.0000, 361.1835],
        [452.0239, 208.3828, 502.1486, 285.4540],
        [554.9769, 214.0762, 601.4109, 248.5285],
        [473.0355, 251.5581, 575.2361, 298.9354],
        [383.1731, 259.1596, 418.4447, 312.5125],
        [265.9569, 143.7254, 640.0000, 311.1364],
        [353.1688, 200.4693, 494.6974, 272.1262],
        [229.8953, 142.8851, 254.5031, 226.0164]], device='cuda:0',
       grad_fn=<StackBackward>), 'labels': tensor([ 1,  1,  1, 31, 31, 31,  1, 15,  1,  1, 15,  1, 15, 31, 62, 62, 15, 15,
         1, 18, 31, 62,  1, 31, 15, 31, 31,  1, 31, 32, 15, 15, 77,  1, 15, 27,
         1, 31, 31, 31, 62, 64, 31,  1, 15, 15, 62, 77, 15, 15, 15, 67, 62, 62,
        27, 64, 15, 15, 31, 15, 44, 15, 15, 31], device='cuda:0'), 'scores': tensor([0.9995, 0.9995, 0.9978, 0.9925, 0.9922, 0.9896, 0.9828, 0.9582, 0.8994,
        0.8727, 0.8438, 0.8364, 0.7470, 0.7322, 0.6674, 0.5940, 0.4650, 0.3875,
        0.3826, 0.3792, 0.3722, 0.3720, 0.3480, 0.3407, 0.2381, 0.2210, 0.2163,
        0.2060, 0.1994, 0.1939, 0.1769, 0.1652, 0.1589, 0.1521, 0.1516, 0.1499,
        0.1495, 0.1419, 0.1248, 0.1184, 0.1124, 0.1098, 0.1077, 0.1059, 0.1035,
        0.0986, 0.0975, 0.0910, 0.0909, 0.0882, 0.0863, 0.0802, 0.0733, 0.0709,
        0.0699, 0.0668, 0.0662, 0.0651, 0.0600, 0.0586, 0.0578, 0.0578, 0.0577,
        0.0540], device='cuda:0', grad_fn=<IndexBackward>)}]

cuda:1(或除 cuda:0 以外的任何 GPU)上的输出

[{'boxes': tensor([[218.7705,   0.0000, 640.0000, 491.0000]], device='cuda:1',
       grad_fn=<StackBackward>), 'labels': tensor([77], device='cuda:1'), 'scores': tensor([0.0646], device='cuda:1', grad_fn=<IndexBackward>)}]

标签: pytorchobject-detectiontorchtorchvision

解决方案


推荐阅读