首页 > 解决方案 > 是否可以加载使用 model.module.state_dict() 存储但使用 model.state_dict() 加载的模型

问题描述

我想问一个问题,我已经用两个gpu训练了一个模型并用model.module.state_dict()存储了这个模型,现在我想在一个gpu中加载这个模型,我可以直接用model.state_dict加载这个训练好的模型吗()?

提前致谢!

标签: neural-networkpytorch

解决方案


你可以参考这个问题。

您可以添加一个nn.DataParallel用于加载目的。或更改密钥命名,如

# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

model.module.state_dict()但是当您保存模型时model.state_dict(),名称可能会有所不同。如果上述两种方法都不起作用,请尝试打印保存的 dict 和 model 以查看需要更改的内容。喜欢

state_dict = torch.load('myfile.pth.tar')
print(state_dict)
print(model)

推荐阅读