首页 > 解决方案 > 如何在pytorch的修改后的vgg19网络中加载预训练的权重?

问题描述

我正在尝试使用修改后的输入通道数加载 vgg19 网络。输入通道的数量是 4 是我的情况,而且我正在将分类器更改为我自己的分类器。我还从网络中删除了自适应平均池化层。我应该如何在 PyTorch 中将预训练的权重加载到我的模型的修改版本中?

假设我的模型的修改版本在变量 myModel 中。我怎样才能将 vgg19 的预训练权重加载到相同的位置?

标签: vgg-netpre-trained-model

解决方案


选项 1. 如果要使用原始 VGG19 网络给出的原始预训练权重,则必须先加载权重,然后再修改网络。预训练的权重是为原始网络定义的,因此它需要匹配输入通道。然后您可以在开头添加一个额外的层作为输入层,并在新网络中删除池化层。

选项 2。您可以分别加载除输入层之外的所有层的权重,因为会有尺寸不匹配。

在代码中它看起来像这样 -

  # corresp_name is a dict object with mapping for your given layer 
  # name and original models layer name
  p_dict = torch.load(Path.model_dir()) #p_dict is my_model
  s_dict = self.state_dict()
  for name in p_dict:
      if name not in corresp_name:
            continue
      s_dict[corresp_name[name]] = p_dict[name]
  self.load_state_dict(s_dict)

推荐阅读