Error when modifying the number of input channels of inception-v3

Problem description

I want to customize inception_v3 so that it works with 4-channel input, so I tried to modify the first layer of Inception v3 as follows.

import torch
import torch.nn as nn
from torchvision import models

x=torch.randn((5,4,299,299))

model_ft=models.inception_v3(pretrained=True)
# swap the first conv so it accepts 4 input channels instead of 3
model_ft.Conv2d_1a_3x3.conv=nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
print(x.shape)
print(model_ft.Conv2d_1a_3x3.conv)
out=model_ft(x)

But it produces the following error. I think the input shape and the network have both been modified correctly, so I don't understand why the error occurs. Does anyone have a suggestion?

torch.Size([5, 4, 299, 299])
Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

RuntimeError                              Traceback (most recent call last)
<ipython-input-118-41c045338348> in <module>
     29 print(model_ft.Conv2d_1a_3x3.conv)
     30 
---> 31 out=model_ft(x)
     32 print(out)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
    202     def forward(self, x: Tensor) -> InceptionOutputs:
    203         x = self._transform_input(x)
--> 204         x, aux = self._forward(x)
    205         aux_defined = self.training and self.aux_logits
    206         if torch.jit.is_scripting():

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in _forward(self, x)
    141     def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
    142         # N x 3 x 299 x 299
--> 143         x = self.Conv2d_1a_3x3(x)
    144         # N x 32 x 149 x 149
    145         x = self.Conv2d_2a_3x3(x)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/torchvision/models/inception.py in forward(self, x)
    474 
    475     def forward(self, x: Tensor) -> Tensor:
--> 476         x = self.conv(x)
    477         x = self.bn(x)
    478         return F.relu(x, inplace=True)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    441 
    442     def forward(self, input: Tensor) -> Tensor:
--> 443         return self._conv_forward(input, self.weight, self.bias)
    444 
    445 class Conv3d(_ConvNd):

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    438                             _pair(0), self.dilation, self.groups)
    439         return F.conv2d(input, weight, bias, self.stride,
--> 440                         self.padding, self.dilation, self.groups)
    441 
    442     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [32, 4, 3, 3], expected input[5, 3, 299, 299] to have 4 channels, but got 3 channels instead

Tags: python, image-processing, deep-learning, computer-vision, pytorch

Solution


I found that pretrained=True applies the ImageNet normalization to the input before the rest of the network runs, as can be seen here. That normalization is written for 3-channel images: _transform_input slices out channels 0, 1 and 2, rescales them, and concatenates them back together, so a 4-channel input has already been reduced to 3 channels by the time it reaches the modified first conv. That is why the error occurs.

def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    ...
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
    ...
    return Inception3(**kwargs)


class Inception3(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
    ) -> None:
        super(Inception3, self).__init__()
        ...
        self.transform_input = transform_input
        ...

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        ...
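
You can confirm this with a quick check (assuming the same torchvision build as in the traceback above): with transform_input=True, which pretrained=True sets automatically, the normalization keeps only channels 0, 1 and 2, so a 4-channel batch has already become a 3-channel batch before Conv2d_1a_3x3 is called.

import torch
from torchvision import models

x = torch.randn((5, 4, 299, 299))
model_ft = models.inception_v3(pretrained=True)
print(model_ft.transform_input)             # True, set automatically by pretrained=True
print(model_ft._transform_input(x).shape)   # torch.Size([5, 3, 299, 299]) - only the first 3 channels survive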

With that in mind, I was finally able to use the pretrained model in the following way.

import torch
import torch.nn as nn
from torchvision import models

x=torch.randn((5,4,299,299))
model_ft=models.inception_v3(pretrained=True)
model_ft.transform_input=False   # skip the 3-channel ImageNet normalization
model_ft.Conv2d_1a_3x3.conv=nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
out=model_ft(x)
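
One caveat: the freshly created nn.Conv2d above is randomly initialized, so the pretrained weights of the first layer are thrown away. Below is a minimal sketch of one way to carry them over, copying the pretrained 3-channel kernel into the first three input channels of the new layer; initializing the fourth channel with the mean of the other three is just one possible choice, not something torchvision prescribes.

import torch
import torch.nn as nn
from torchvision import models

model_ft = models.inception_v3(pretrained=True)
model_ft.transform_input = False

old_conv = model_ft.Conv2d_1a_3x3.conv                  # pretrained Conv2d(3, 32, ...)
new_conv = nn.Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

with torch.no_grad():
    new_conv.weight[:, :3] = old_conv.weight             # reuse the pretrained RGB kernels
    new_conv.weight[:, 3] = old_conv.weight.mean(dim=1)  # assumed init for the extra channel

model_ft.Conv2d_1a_3x3.conv = new_conv
out = model_ft(torch.randn((5, 4, 299, 299)))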
