neural-network - Error getting state_dict from a serialized model in PyTorch
Problem description
I trained my encoder-decoder model and saved it with:
model_state = {
    'encoder': encoder,
    'encoder_optimizer': encoder_optimizer,
    'decoder': decoder,
    'decoder_optimizer': decoder_optimizer
}
torch.save(model_state, "best_model.pth.tar")
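For reference, the generally recommended pattern is to checkpoint `state_dict()`s rather than the live module objects, since pickling whole modules ties the file to the exact class definitions and PyTorch version used at save time. A minimal sketch with hypothetical stand-in modules (the `nn.Linear` layers stand in for the real encoder/decoder):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-ins for the encoder/decoder above.
encoder = nn.Linear(4, 2)
decoder = nn.Linear(2, 4)
encoder_optimizer = torch.optim.Adam(encoder.parameters())
decoder_optimizer = torch.optim.Adam(decoder.parameters())

# Save only the state_dicts, not the module objects themselves.
path = os.path.join(tempfile.mkdtemp(), "best_model.pth.tar")
torch.save({
    'encoder': encoder.state_dict(),
    'encoder_optimizer': encoder_optimizer.state_dict(),
    'decoder': decoder.state_dict(),
    'decoder_optimizer': decoder_optimizer.state_dict(),
}, path)

# To resume, rebuild the modules from code and load the weights back in.
checkpoint = torch.load(path, map_location="cpu")
new_decoder = nn.Linear(2, 4)
new_decoder.load_state_dict(checkpoint['decoder'])
```

With this layout the checkpoint contains only tensors, so it can be loaded by any PyTorch version that can reconstruct the modules from source.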
This works fine when I use the model on its own, but when I try to use it in another application it gives me an error. So I tried to load the model and save the encoder and decoder as state_dicts. This works for my encoder, but when I try:
checkpoint = torch.load(path_to_model, map_location=torch.device("cpu"))
decoder = checkpoint['decoder']
decoder = decoder.to(device)
encoder = checkpoint['encoder']
encoder = encoder.to(device)
torch.save(encoder.state_dict(), 'encoder.dict')
torch.save(decoder.state_dict(), 'decoder.dict')
It fails at torch.save(decoder.state_dict(), 'decoder.dict') and I get the error:
File "<stdin>", line 1, in <module>
File "caption.py", line 31, in load_maps
torch.save(decoder.state_dict(), 'decoder.dict')
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 695, in state_dict
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 695, in state_dict
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 696, in state_dict
for hook in self._state_dict_hooks.values():
File "/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 585, in __getattr__
type(self).__name__, name))
AttributeError: 'Softmax' object has no attribute '_state_dict_hooks'
Is there a way to get around this error or recreate the state_dict without retraining my model? I don't understand why I can't extract the state_dict if I saved the entire model; it should be part of the model.
Here is the output of calling for m in decoder.modules(): print(m):
DecoderWithAttention(
(attention): Attention(
(encoder_att): Linear(in_features=2048, out_features=512, bias=True)
(decoder_att): Linear(in_features=512, out_features=512, bias=True)
(full_att): Linear(in_features=512, out_features=1, bias=True)
(relu): ReLU()
(softmax): Softmax(dim=1)
)
(embedding): Embedding(9490, 512)
(dropout): Dropout(p=0.5, inplace=False)
(decode_step): LSTMCell(2560, 512, bias=1)
(init_h): Linear(in_features=2048, out_features=512, bias=True)
(init_c): Linear(in_features=2048, out_features=512, bias=True)
(f_beta): Linear(in_features=512, out_features=2048, bias=True)
(sigmoid): Sigmoid()
(fc): Linear(in_features=512, out_features=9490, bias=True)
)
Attention(
(encoder_att): Linear(in_features=2048, out_features=512, bias=True)
(decoder_att): Linear(in_features=512, out_features=512, bias=True)
(full_att): Linear(in_features=512, out_features=1, bias=True)
(relu): ReLU()
(softmax): Softmax(dim=1)
)
Linear(in_features=2048, out_features=512, bias=True)
Linear(in_features=512, out_features=512, bias=True)
Linear(in_features=512, out_features=1, bias=True)
ReLU()
Softmax(dim=1)
Embedding(9490, 512)
Dropout(p=0.5, inplace=False)
LSTMCell(2560, 512, bias=1)
Linear(in_features=2048, out_features=512, bias=True)
Linear(in_features=2048, out_features=512, bias=True)
Linear(in_features=512, out_features=2048, bias=True)
Sigmoid()
Linear(in_features=512, out_features=9490, bias=True)
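A quick way to pinpoint which submodule lost the attribute is to probe every module in the tree for it. This is a hedged sketch: the `Sequential` is a small stand-in for the loaded decoder, and deleting `_state_dict_hooks` simulates the state the unpickling left the `Softmax` in.

```python
import torch.nn as nn

# Small stand-in for the loaded decoder; deleting the internal attribute
# simulates a submodule unpickled from an older PyTorch.
decoder = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=1))
del decoder[1]._state_dict_hooks

# Probe every submodule for the attribute that state_dict() iterates over.
missing = [name for name, m in decoder.named_modules()
           if not hasattr(m, '_state_dict_hooks')]
print(missing)  # the names of the broken submodules, here the Softmax
```

In the real model the printed names would be dotted paths such as `attention.softmax`, which tells you exactly which submodule to repair.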
Solution
Try saving the model like this:
torch.save({'state_dict': decoder.state_dict()}, 'decoder.pth.tar')
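If `decoder.state_dict()` itself still raises the AttributeError, the likely cause is that the checkpoint pickled the full module objects under an older PyTorch, so the unpickled `Softmax` is missing internal attributes (such as `_state_dict_hooks`) that the current version's `state_dict()` expects. Since `Softmax` holds no parameters, one workaround is to swap in a freshly constructed instance before saving. A sketch with a hypothetical stand-in module (deleting the attribute simulates the cross-version unpickling):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    # Hypothetical stand-in mirroring the question's Attention submodule.
    def __init__(self):
        super().__init__()
        self.full_att = nn.Linear(512, 1)
        self.softmax = nn.Softmax(dim=1)

att = Attention()
# Simulate unpickling from an older PyTorch: the bookkeeping attribute
# that state_dict() walks is missing on the Softmax.
del att.softmax._state_dict_hooks

try:
    att.state_dict()
except AttributeError as e:
    print(e)  # the AttributeError from the question

# Softmax has no trainable weights, so replacing it is lossless; for the
# model in the question this would be: decoder.attention.softmax = nn.Softmax(dim=1)
att.softmax = nn.Softmax(dim=1)
state = att.state_dict()  # now succeeds
```

After the swap, `torch.save(decoder.state_dict(), ...)` should work, and the resulting file only contains tensors, so it is portable across PyTorch versions.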