Reducing a PyTorch U-Net model to less than half the parameters shows almost no memory reduction


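Is anyone else using a U-Net in PyTorch and running into problems with memory usage? I have tried a few things to reduce the memory I am using, but I have hit a rather strange problem.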

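Background - I am trying to train a multi-class U-Net to predict 5 classes from 3D images, in a way very similar to this paper - https://arxiv.org/pdf/1606.06650.pdf. My images are 320x150x26 pixels. I am using half precision and a batch size of 2. The whole process takes a long time, and I would like to increase the batch size.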

My first U-Net architecture looks like this:

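import torch
import torch.nn as nn
# `mp` is my own module that provides DoubleConv, Down, Up and OutConv (not shown here).

class NetULarge(nn.Module):
    def __init__(self, dtype=torch.float16):
        super().__init__()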
        self.n_channels = 1
        self.n_classes = 5 # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True

        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 512, dtype=dtype)
        self.down4 = mp.Down(512, 512, dtype=dtype)
        self.up1 = mp.Up(1024, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(512, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(256, 64, self.bilinear, dtype=dtype)
        self.up4 = mp.Up(128, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        return out


pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

I get 40158853, which is far more than in the paper in question, so I thought I would shrink the model to the following:

class NetU(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetU, self).__init__()
        self.n_channels = 1
        self.n_classes = 5 # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True

        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 256, dtype=dtype)
        self.up1 = mp.Up(512, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(384, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(192, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        out = self.outc(x)
        return out

Now I get 14496517 parameters - a significant reduction.

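However, this has not reduced memory usage by any reasonable amount. I first used nvidia-smi to watch the memory usage on my GPU. Both networks peak at around 7000 MiB. At times both drop to roughly 1500 MiB, but there is almost no difference between them.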
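For scale, here is a rough sketch of how much memory the weights alone would take, using the two parameter counts quoted above and assuming 2 bytes per value because the models are built with float16. It ignores gradients, optimizer state and activations, so it is only a lower bound, not a measurement.

```python
# Back-of-the-envelope size of the weights alone, using the two
# parameter counts quoted in the post and 2 bytes per float16 value.
BYTES_PER_FP16 = 2
MIB = 2 ** 20

param_counts = {"NetULarge": 40_158_853, "NetU": 14_496_517}

# Weight memory in MiB for each network (weights only, no gradients).
weight_mib = {
    name: count * BYTES_PER_FP16 / MIB
    for name, count in param_counts.items()
}

for name, mib in weight_mib.items():
    print(f"{name}: {mib:.1f} MiB of float16 weights")

saved_mib = weight_mib["NetULarge"] - weight_mib["NetU"]
print(f"difference: {saved_mib:.1f} MiB")
```

Under these assumptions the weights come to well under 100 MiB for either network, so removing parameters can only account for a small part of a peak near 7000 MiB.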



      layer_idx  call_idx   layer_type  exp hook_type     mem_all  mem_cached
0             0         0         NetU    0       pre    87474176   106954752
1             1         1   DoubleConv    0       pre    87474176   106954752
2             2         2   Sequential    0       pre    87474176   106954752
3             3         3       Conv3d    0       pre    87474176   106954752
4             3         4       Conv3d    0       fwd   406962176   427819008
670          91       670  BatchNorm3d    0       bwd  4954478592  5471469568


          layer_idx  call_idx   layer_type  exp hook_type     mem_all  mem_cached
0             0         0         NetU    0       pre    34027520    46137344
1             1         1   DoubleConv    0       pre    34027520    46137344
2             2         2   Sequential    0       pre    34027520    46137344
3             3         3       Conv3d    0       pre    34027520    46137344
4             3         4       Conv3d    0       fwd   353515520   367001600
966          62       966           Up    0       bwd  5632503808  5937037312


1563         45      1563   DoubleConv    0       bwd  3151531008  6121586688

This cached size seems to be approaching the roughly 7000 MiB that I see with nvidia-smi.
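The jump in mem_all between the `pre` and `fwd` rows of the first Conv3d can be checked by hand. The sketch below assumes that this convolution keeps the full 320x150x26 spatial size and outputs 64 channels (as DoubleConv(1, 64) suggests), with a batch size of 2 and float16 values. The variable names are mine, and the byte counts are copied from the two tables above.

```python
# Assumption: the first Conv3d preserves the 320x150x26 spatial size and
# produces 64 channels, with batch size 2 and 2 bytes per float16 value.
batch, channels = 2, 64
spatial_shape = (320, 150, 26)  # axis order does not matter for the count
bytes_per_fp16 = 2

voxels = 1
for n in spatial_shape:
    voxels *= n

# Size of the output tensor of the first convolution, in bytes.
first_conv_output_bytes = batch * channels * voxels * bytes_per_fp16

# mem_all after minus before the first Conv3d forward call, from each table.
first_table_jump = 406_962_176 - 87_474_176
second_table_jump = 353_515_520 - 34_027_520

print(first_conv_output_bytes, first_table_jump, second_table_jump)
```

All three numbers are 319488000 bytes, so in both tables the first convolution's output tensor alone adds about 305 MiB. This only checks a single layer, but it suggests that the full-resolution activations, which have the same shape in both networks, matter far more than the parameter count.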


Tags: python, memory, pytorch

