首页 > 解决方案 > 如何优化用于 TensorRT 推理的 grid_sample 的自定义双线性采样替代方案?

问题描述

我试图通过 ONNX(opset 11)将模型从 Pytorch(1.6)转换为 TensorRT(7)的 torch.nn.functional.grid_sample。Opset 11 不支持 grid_sample 转换。我发现的自定义替代方案(https://github.com/pytorch/pytorch/issues/27212)在 Pytorch 中运行时非常慢,并且在将主循环转换为 TRT 时存在问题。

我自己的双线性采样实现(不仅仅是 grid_sample,而是整个原始采样,基于 grid_sample)在 Pytorch 中执行得更快,并成功转换为 TRT。但是我在 TRT 中的自定义双线性采样比 Pytorch 中的慢(5.6 ms vs 2.0 ms)。事实证明,Pytorch image[:, ind, y0, x0] 索引生成 Gather 层的运行时间约为 0.97 ms。在这种双线性采样的 TRT 版本中有 4 个这样的层。

所以问题是:

这是双线性采样函数的代码:

def bilinear_sample_noloop(image, grid):
    """
    :param image: sampling source of shape [N, C, H, W]
    :param grid: integer sampling pixel coordinates of shape [N, grid_H, grid_W, 2]
    :return: sampling result of shape [N, C, grid_H, grid_W]
    """
    Nt, C, H, W = image.shape
    grid_H = grid.shape[1]
    grid_W = grid.shape[2]
    xgrid, ygrid = grid.split([1, 1], dim=-1)
    mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
    x0 = torch.floor(xgrid)
    x1 = x0 + 1
    y0 = torch.floor(ygrid)
    y1 = y0 + 1
    wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
    wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
    x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
    y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
    x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
    y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
    ind = torch.arange(Nt, device=image.device) #torch.linspace(0, Nt - 1, Nt, device=image.device)
    ind = ind.view(Nt, 1).expand(-1, grid_H).view(Nt, grid_H, 1).expand(-1, -1, grid_W).long()
    image = image.permute(1, 0, 2, 3)
    output_tensor = (image[:, ind, y0, x0] * wa + image[:, ind, y1, x0] * wb + image[:, ind, y0, x1] * wc + \
                 image[:, ind, y1, x1] * wd).permute(1, 0, 2, 3)
    output_tensor *= mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
    image = image.permute(1, 0, 2, 3)
    return output_tensor, mask

时间分析参数:

使用 trtexec 进行 TRT 模型分析的一部分:

     Layer   Time (ms)   Avg. Time (ms)   Time %
...
   Mul_146        5.82             0.03      0.5
   Add_147        8.50             0.04      0.7
Gather_148      214.39             0.97     17.3
Gather_174      214.25             0.97     17.3
Gather_201      213.88             0.97     17.3
Gather_228      214.48             0.97     17.3
 Add_237))       25.01             0.11      2.0
   Mul_251        7.84             0.04      0.6
     Total     1238.40             5.60    100.0

此外,我尝试将图像视为除 C 之外的所有维度上的线性数组,并创建线性索引以寻址图像 [:, p0] 形式的元素。在这种情况下,Gather 变得更慢(大约 1.07 毫秒)。然后我考虑了 C=1(因为它总是在原始模型中)并将张量元素处理为 image[p0]。这次 Gather 大约需要 0.92 毫秒(仍然太慢)。

标签: pytorchonnxtensorrtbilinear-interpolationgather

解决方案


以下代码可用于将 Pytorch 转换为 TensorRT 作为图像的 bilinear_interpolate

def bilinear_interpolate_torch(im, y, x):
'''
   im : B,C,H,W
   y : 1,numPoints -- pixel location y float
   x : 1,numPOints -- pixel location y float
'''


x0 = torch.floor(x).type(torch.cuda.LongTensor)
x1 = x0 + 1

y0 = torch.floor(y).type(torch.cuda.LongTensor)
y1 = y0 + 1

wa = (x1.type(torch.cuda.FloatTensor) - x) * (y1.type(torch.cuda.FloatTensor) - y)
wb = (x1.type(torch.cuda.FloatTensor) - x) * (y - y0.type(dtype))
wc = (x - x0.type(torch.cuda.FloatTensor)) * (y1.type(torch.cuda.FloatTensor) - y)
wd = (x - x0.type(torch.cuda.FloatTensor)) * (y - y0.type(torch.cuda.FloatTensor))
# Instead of clamp
x1 = x1 - torch.floor(x1 / im.shape[3]).int()
y1 = y1 - torch.floor(y1 / im.shape[2]).int()
Ia = im[:, :, y0, x0]
Ib = im[:, :, y1, x0]
Ic = im[:, :, y0, x1]
Id = im[:, :, y1, x1]

return Ia  * wa + Ib * wb + Ic * wc + Id * wd

推荐阅读