PyTorch size mismatch when loading a pretrained model

Problem description

Can someone help me with this problem? I haven't been able to solve it myself.

I want to load my pretrained model in order to run reinforcement learning. This is how I save and load the model:

def save(self, folder_to_save='./'):
    if folder_to_save[-1] != '/':
        folder_to_save = folder_to_save + '/'
    # Each submodule's state dict is written to its own file.
    torch.save(self.enc.state_dict(), folder_to_save + 'enc.model')
    torch.save(self.dec.state_dict(), folder_to_save + 'dec.model')
    torch.save(self.lp.state_dict(), folder_to_save + 'lp.model')

    pickle.dump(self.lp.order, open(folder_to_save + 'order.pkl', 'wb'))

def load(self, folder_to_load='./'):
    if folder_to_load[-1] != '/':
        folder_to_load = folder_to_load + '/'

    # The lp module is rebuilt from the pickled order before its weights are loaded.
    order = pickle.load(open(folder_to_load + 'order.pkl', 'rb'))
    self.lp = LP(distr_descr=self.latent_descr + self.feature_descr,
                 tt_int=self.tt_int, tt_type=self.tt_type,
                 order=order)

    self.enc.load_state_dict(torch.load(folder_to_load + 'enc.model'))
    self.dec.load_state_dict(torch.load(folder_to_load + 'dec.model'))
    self.lp.load_state_dict(torch.load(folder_to_load + 'lp.model'))
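
For reference, the shapes stored in each saved file can be inspected directly with torch.load; a minimal sketch, assuming the saved_gentrl/ folder used in the notebook below:

import torch

# Load the encoder checkpoint on CPU and list every parameter with its shape.
state = torch.load('saved_gentrl/enc.model', map_location='cpu')
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# e.g. embs.weight (28, 256) -- the vocabulary size at save time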

In the .ipynb file I am using:

model.save('./saved_gentrl/')

# Rebuild a fresh model, then load the saved weights.
enc = gentrl.RNNEncoder(latent_size=50)
dec = gentrl.DilConvDecoder(latent_input_size=50)
model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)
model.cuda();
model.load('saved_gentrl/')
model.cuda();

But this error occurs:

RuntimeError: Error(s) in loading state_dict for RNNEncoder:
    size mismatch for embs.weight: copying a param with shape torch.Size([28, 256]) from checkpoint, the shape in current model is torch.Size([31, 256]).

The traceback is:

/Colab Notebooks/GENTRL-master/gentrl/gentrl.py in load(self, folder_to_load)
114                      order=order)
115 
116         self.enc.load_state_dict(torch.load(folder_to_load + 'enc.model'))
117         self.dec.load_state_dict(torch.load(folder_to_load + 'dec.model'))
118         self.lp.load_state_dict(torch.load(folder_to_load + 'lp.model'))

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)

1050         load = None  # break load->load reference cycle
1052         return _IncompatibleKeys(missing_keys, unexpected_keys)
1053 
1054     def _named_members(self, get_members_fn, prefix='', recurse=True):

I guess this may be because I added 3 features to the token list, which changed vocab_size from 28 to 31, but this parameter is only used as:

self.embs = nn.Embedding(get_vocab_size(), hidden_size)
self.vocab_size = get_vocab_size()

I don't know why this happens; please help me. Thanks!
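
The mismatch can be reproduced in isolation with nothing but two embedding layers; a minimal sketch, assuming hidden_size = 256 as reported in the error message:

import torch.nn as nn

# The checkpoint was written when the vocabulary had 28 tokens,
# but the freshly built model now has 31 embedding rows.
old_embs = nn.Embedding(28, 256)
new_embs = nn.Embedding(31, 256)

# Raises the same RuntimeError: size mismatch for weight,
# torch.Size([28, 256]) vs. torch.Size([31, 256]).
new_embs.load_state_dict(old_embs.state_dict())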

Tags: python, pytorch

Solution
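
One possible way to make the old checkpoint fit a model whose vocabulary grew from 28 to 31 tokens is to copy the 28 overlapping rows of embs.weight and leave the 3 new rows at their fresh initialization. The helper below is a sketch of that idea, not part of the GENTRL API; it assumes the new tokens were appended after the original 28, so the old rows still line up:

import torch

def load_with_grown_vocab(module, checkpoint_path):
    """Load a state dict whose embs.weight has fewer rows than the current module.

    Assumes the new vocabulary is the old one plus tokens appended at the end,
    so the first rows of embs.weight still correspond to the checkpoint's rows.
    """
    state = torch.load(checkpoint_path, map_location='cpu')
    old_weight = state['embs.weight']               # e.g. [28, 256]
    new_weight = module.embs.weight.data.clone()    # e.g. [31, 256]

    # Copy the overlapping rows; the rows for the added tokens keep their init.
    new_weight[:old_weight.shape[0]] = old_weight
    state['embs.weight'] = new_weight

    module.load_state_dict(state)

# Hypothetical usage, in place of the strict enc.load_state_dict(...) call:
# load_with_grown_vocab(enc, 'saved_gentrl/enc.model')

Alternatively, rebuilding the model with the original 28-token vocabulary (the same token list that was in effect when saving) lets model.load('saved_gentrl/') run unchanged.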

