python - 加载预训练模型时 Pytorch 大小不匹配
问题描述
有人可以帮我解决这个问题,因为我无法做到这一点。
我想加载我的预训练模型以进行强化学习。这就是我保存和加载模型的方式:
def save(self, folder_to_save='./'):
if folder_to_save[-1] != '/':
folder_to_save = folder_to_save + '/'
torch.save(self.enc.state_dict(), folder_to_save + 'enc.model')
torch.save(self.dec.state_dict(), folder_to_save + 'dec.model')
torch.save(self.lp.state_dict(), folder_to_save + 'lp.model')
pickle.dump(self.lp.order, open(folder_to_save + 'order.pkl', 'wb'))
def load(self, folder_to_load='./'):
if folder_to_load[-1] != '/':
folder_to_load = folder_to_load + '/'
order = pickle.load(open(folder_to_load + 'order.pkl', 'rb'))
self.lp = LP(distr_descr=self.latent_descr + self.feature_descr,
tt_int=self.tt_int, tt_type=self.tt_type,
order=order)
self.enc.load_state_dict(torch.load(folder_to_load + 'enc.model'))
self.dec.load_state_dict(torch.load(folder_to_load + 'dec.model'))
self.lp.load_state_dict(torch.load(folder_to_load + 'lp.model'))
在我使用的 .ipynb 文件中:
model.save('./saved_gentrl/')
enc = gentrl.RNNEncoder(latent_size=50)
dec = gentrl.DilConvDecoder(latent_input_size=50)
model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)
model.cuda();
model.load('saved_gentrl/')
model.cuda();
但是会发生错误:
RuntimeError: Error(s) in loading state_dict for RNNEncoder:size mismatch for embs.weight: copying a param with shape torch.Size([28, 256]) from checkpoint, the shape in current model is torch.Size([31, 256]).
回溯是:
/Colab Notebooks/GENTRL-master/gentrl/gentrl.py in load(self, folder_to_load)
114 order=order)
115
116 self.enc.load_state_dict(torch.load(folder_to_load + 'enc.model'))
117 self.dec.load_state_dict(torch.load(folder_to_load + 'dec.model'))
118 self.lp.load_state_dict(torch.load(folder_to_load + 'lp.model'))
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1050 load = None # break load->load reference cycle
1052 return _IncompatibleKeys(missing_keys, unexpected_keys)
1053
1054 def _named_members(self, get_members_fn, prefix='', recurse=True):
我猜这可能是因为我在token list中增加了3个特征,导致vocab_size从28变为31,但是这个参数只用作:</p>
self.embs = nn.Embedding(get_vocab_size(), hidden_size)
self.vocab_size = get_vocab_size()
我不知道为什么,请帮助我。谢谢!
解决方案
推荐阅读
- apache - 如何从 apache conf 中的响应标头中删除“Content-Type: text/html;charset=UTF-8”
- permissions - Google DLP 用户的访问被拒绝
- python - 与列表切片相反,元组切片不返回新对象
- powershell - 如何将开关参数作为变量/通过 PowerShell 中的 splatting 传递?
- amazon-web-services - 向 Application Load Balancer (ALB) 注册 EC2 实例时出现问题
- http - 确定 content-type 和 content-encoding,仍然得到 415 Unsupported media type
- java - 在不知道其列的情况下将值插入动态 mysql 表
- java - 使用 Java GDAL API 变形
- unix - 用字符“替换第一列的重复值
- python - 通过按下按钮克隆标签、按钮和文本框