首页 > 解决方案 > 如何在 PyTorch 中保存模型架构?

问题描述

我知道我可以通过torch.save(model.state_dict(), FILE)或保存模型torch.save(model, FILE)。但是它们都没有保存模型的架构。

那么我们如何在 PyTorch 中保存模型的架构,就像在 Tensorflow 中创建.pb文件一样?我想对我的模型应用不同的调整。如果我无法保存模型的体系结构,是否有比每次复制整个类定义并创建一个新类更好的方法?

标签: pytorch

解决方案


你可以参考这篇文章来了解如何保存分类器。要对模型进行调整,您可以创建一个新模型,它是现有模型的子模型。


class newModel( oldModelClass):
    def __init__(self):
        super(newModel, self).__init__()

通过这种设置,newModel 具有所有层以及oldModelClass. 如果需要进行调整,可以在__init__函数中定义新层,然后编写一个新的 forward 函数来定义它。


推荐阅读