vgg-net - 如何在pytorch的修改后的vgg19网络中加载预训练的权重?
问题描述
我正在尝试使用修改后的输入通道数加载 vgg19 网络。输入通道的数量是 4 是我的情况,而且我正在将分类器更改为我自己的分类器。我还从网络中删除了自适应平均池化层。我应该如何在 PyTorch 中将预训练的权重加载到我的模型的修改版本中?
假设我的模型的修改版本在变量 myModel 中。我怎样才能将 vgg19 的预训练权重加载到相同的位置?
解决方案
选项 1. 如果要使用原始 VGG19 网络给出的原始预训练权重,则必须先加载权重,然后再修改网络。预训练的权重是为原始网络定义的,因此它需要匹配输入通道。然后您可以在开头添加一个额外的层作为输入层,并在新网络中删除池化层。
选项 2。您可以分别加载除输入层之外的所有层的权重,因为会有尺寸不匹配。
在代码中它看起来像这样 -
# corresp_name is a dict object with mapping for your given layer
# name and original models layer name
p_dict = torch.load(Path.model_dir()) #p_dict is my_model
s_dict = self.state_dict()
for name in p_dict:
if name not in corresp_name:
continue
s_dict[corresp_name[name]] = p_dict[name]
self.load_state_dict(s_dict)
推荐阅读
- javascript - 使用Javascript生成多个html表单并将表单数据导出到excel
- wordpress - WordPress REST API 通过 Slug 接收单个帖子的数据?
- javascript - HTML 拖放 - 当可拖动元素未放置在有效的放置代理上时接收事件
- typescript - 如何在打字稿快递服务器中将字符串解码/编码为base64
- android - 在反应原生的特定页面上禁用幻灯片菜单[抽屉]
- android - 通过 Intent 为 WhatsApp、Telegram 等共享图像和链接
- scala - 计算数据库调用所需时间的更好方法是什么?
- android - 获取应用程序的哈希
- android - 如果使用不同的语言(英语到阿拉伯语),如何更改 TextView 的重力
- python - 如何在图形人口模型中修复“**或 pow() 不支持的操作数类型:‘builtin_function_or_method’和‘float’”?