首页 > 解决方案 > 如何从 pytorch 网络中删除 maxpooling 层

问题描述

我在 pytorch v3.1 上有带有 resnet 编码器的“简单”Unet,它工作得很好:

class UNetResNet(nn.Module):
    """PyTorch U-Net model using ResNet(34, 101 or 152) encoder.

    UNet: https://arxiv.org/abs/1505.04597
    ResNet: https://arxiv.org/abs/1512.03385

    Args:
            encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
            num_classes (int): Number of output classes.
            num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
            dropout_2d (float, optional): Probability factor of dropout layer before output layer. Defaults to 0.2.
            pretrained (bool, optional):
                False - no pre-trained weights are being used.
                True  - ResNet encoder is pre-trained on ImageNet.
                Defaults to False.
            is_deconv (bool, optional):
                False: bilinear interpolation is used in decoder.
                True: deconvolution is used in decoder.
                Defaults to False.

    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU(inplace=True)

        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1

        self.conv3 = self.encoder.layer2

        self.conv4 = self.encoder.layer3

        self.conv5 = self.encoder.layer4

        self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                   is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
                                   is_deconv)
        self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                   is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        pool = self.pool(conv5) # that pool layer I would like to delete
        center = self.center(pool)


        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

通过删除该池层,forward我们来到下一个forward函数:

def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(conv5) #now center connects with conv5

        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        return self.final(F.dropout2d(dec0, p=self.dropout_2d))

仅应用此类更改,我们将面临此类错误:

~/Desktop/ml/salt/unet_models.py in forward(self, x)
    397         center = self.center(conv5)
    398 
--> 399         dec5 = self.dec5(torch.cat([center, conv5], 1))
    400 
    401         dec4 = self.dec4(torch.cat([dec5, conv4], 1))

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 4 and 8 in dimension 2 at /pytorch/torch/lib/THC/generic/THCTensorMath.cu:111

所以我们必须以某种方式改变 dec4,但是怎么做呢?

更多信息:

    class DecoderBlockV2(nn.Module):
        def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
            super(DecoderBlockV2, self).__init__()
            self.in_channels = in_channels

            if is_deconv:
                """
                    Paramaters for Deconvolution were chosen to avoid artifacts, following
                    link https://distill.pub/2016/deconv-checkerboard/
                """

                self.block = nn.Sequential(
                    ConvRelu(in_channels, middle_channels),
                    nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                       padding=1),
            nn.BatchNorm2d(out_channels), 
                    nn.ReLU(inplace=True)
                )
            else:
                self.block = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear'),
                    ConvRelu(in_channels, middle_channels),
                    ConvRelu(middle_channels, out_channels),
                )

        def forward(self, x):
            return self.block(x)

class ConvRelu(nn.Module):
    def __init__(self, in_, out):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x

标签: pythonpython-3.xpytorch

解决方案


您的错误源于您尝试连接 ( torch.cat)centerconv5- 但这些张量的尺寸不匹配。
最初,您有以下空间维度

  • conv5: 4x4
  • pool: 2x2
  • center: 4x4 # upsampled 这样你就可以 concatcenter并且conv5因为它们都是 4x4 张量。
    但是,如果您删除池化层,您最终会得到

  • conv5: 4x4

  • center: 8x8 # upsampling more than you need 现在你不能连接centerconv5因为它们的空间尺寸不匹配。

你能做什么?

  1. 一种选择是删除center以及池化,让您

    dec5 = self.dec5(conv5)
    
  2. 另一种选择是nn.Upsample从 的块中删除部分center,留下你

    self.block = nn.Sequential(
                    ConvRelu(in_channels, middle_channels),
                    ConvRelu(middle_channels, out_channels),
                )
    

推荐阅读