首页 > 解决方案 > 将 PyTorch U-Net 模型减少到不到一半的参数显示几乎没有内存减少

问题描述

我想知道是否还有其他人在 Pytorch 中使用 U-Net 并且在内存使用方面遇到问题?我尝试了一些方法来减少我正在使用的内存,但我遇到了一个相当奇怪的问题。

背景 - 我正在尝试训练多类 u-net 以与本文非常相似的方式从 3D 图像中预测 5 个类 - https://arxiv.org/pdf/1606.06650.pdf。我的图像大小为 320x150x26 像素。我使用的是半精度和 2 的批量大小。整个过程需要很长时间,我想增加批量大小。

我的第一个 U-Net 架构如下所示:

class NetULarge(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetU, self).__init__()
        self.n_channels = 1
        self.n_classes = 5 # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True

        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 512, dtype=dtype)
        self.down4 = mp.Down(512, 512, dtype=dtype)
        self.up1 = mp.Up(1024, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(512, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(256, 64, self.bilinear, dtype=dtype)
        self.up4 = mp.Up(128, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        return out

我决定使用以下代码查看它的参数:

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  

我得到 40158853 - 这比有问题的论文要多得多,所以我想我会把模型的大小缩小到以下:

class NetU(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetU, self).__init__()
        self.n_channels = 1
        self.n_classes = 5 # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True

        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 256, dtype=dtype)
        self.up1 = mp.Up(512, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(384, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(192, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
     
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        out = self.outc(x)
        return out

现在我得到 14496517 个参数 - 显着减少。

然而,这并没有减少任何合理的内存量。我首先使用 nvidia-smi 来观察我的 GPU 上的内存使用情况。两个网络的峰值都在 7000MiB 左右。有时,两者都会下降到大约 1500MiB 左右,但它们之间几乎没有区别。

我决定再深入一点,并在网上找到了一些代码,这些代码使用前向和后向挂钩来列出模型的每个部分正在使用的内容。我不会在这里列出它们,因为痕迹很长,但我确实注意到了一些有趣的东西。

在较小的网络中

      layer_idx  call_idx   layer_type  exp hook_type     mem_all  mem_cached
0             0         0         NetU    0       pre    87474176   106954752
1             1         1   DoubleConv    0       pre    87474176   106954752
2             2         2   Sequential    0       pre    87474176   106954752
3             3         3       Conv3d    0       pre    87474176   106954752
4             3         4       Conv3d    0       fwd   406962176   427819008
.....
670          91       670  BatchNorm3d    0       bwd  4954478592  5471469568

在更大的网络中

          layer_idx  call_idx   layer_type  exp hook_type     mem_all  mem_cached
0             0         0         NetU    0       pre    34027520    46137344
1             1         1   DoubleConv    0       pre    34027520    46137344
2             2         2   Sequential    0       pre    34027520    46137344
3             3         3       Conv3d    0       pre    34027520    46137344
4             3         4       Conv3d    0       fwd   353515520   367001600
....
966          62       966           Up    0       bwd  5632503808  5937037312

较大的网络需要执行更多步骤,但尽管网络之间使用的整体内存有所减少,但两者的峰值水平相同。但是,在较小的网络情况下,缓存内存仍然很高,即使使用的总内存非常低,例如:

1563         45      1563   DoubleConv    0       bwd  3151531008  6121586688

这个缓存大小似乎接近 7000MiB 左右,我用 nvidia-smi 看到

我不确定如何最好地解决这个问题。我知道反向传递有很多工作要做,但是为什么缓存应该这么高,为什么我没有看到更大的减少,因为模型要小得多?干杯乙

标签: pythonmemorypytorch

解决方案


推荐阅读