pytorch - 向预训练模型添加参数
问题描述
在 Pytorch 中,我们加载预训练模型如下:
net.load_state_dict(torch.load(path)['model_state_dict'])
那么网络结构和加载的模型必须完全相同。但是,是否可以加载权重然后修改网络/添加额外的参数?
注意:如果我们在加载权重之前向模型添加一个额外的参数,例如
self.parameter = Parameter(torch.ones(5),requires_grad=True)
Missing key(s) in state_dict:
加载权重时会出错。
解决方案
让我们创建一个模型并保存它的状态。
class Model1(nn.Module):
def __init__(self):
super(Model1, self).__init__()
self.encoder = nn.LSTM(100, 50)
def forward(self):
pass
model1 = Model1()
torch.save(model1.state_dict(), 'filename.pt') # saving model
然后创建第二个模型,它与第一个模型有几个共同的层。加载第一个模型的状态并将其加载到第二个模型的公共层。
class Model2(nn.Module):
def __init__(self):
super(Model2, self).__init__()
self.encoder = nn.LSTM(100, 50)
self.linear = nn.Linear(50, 200)
def forward(self):
pass
model1_dict = torch.load('filename.pt')
model2 = Model2()
model2_dict = model2.state_dict()
# 1. filter out unnecessary keys
filtered_dict = {k: v for k, v in model1_dict.items() if k in model2_dict}
# 2. overwrite entries in the existing state dict
model2_dict.update(filtered_dict)
# 3. load the new state dict
model2.load_state_dict(model2_dict)
推荐阅读
- ruby-on-rails - Rails - 如何查询 M2M 字段
- eclipse - 带有 vrapper 的 Eclipse 无法识别文件更改
- sharepoint-online - 使用 PowerAutomate 邀请来宾用户访问 SharePoint 站点
- excel - 验证数组是否包含不在我允许的字符列表中的字符
- python - PIL 保存功能创建太大的图像
- bash - 增加文件中第一行的最后一个数字
- algorithm - 枚举大小为 k 的子集的高效算法
- java - SSLHandshakeException:为什么直到读取套接字才抛出?
- python - 具有二项式回归的 Python 和 R 的不同 GLM 结果
- linux - 我们所有的 PowerShell 脚本都是用 5.1 版本编写的。我们得到了一个只使用 Linux 服务器的新客户端。我们需要重新编写所有脚本吗?