首页 > 解决方案 > 如何在没有直接引用其类定义且没有莳萝的情况下使用 torch.load/save 腌制自定义对象?

问题描述

我有一个非常简单的场景。我有 1 个自定义类(不是 NN 模型)我正在酸洗torch.save。当我尝试加载它时失败并出现错误:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "/Users/brando/anaconda3/envs/metalearning/lib/python3.8/site-packages/torch/serialization.py", line 594, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/Users/brando/anaconda3/envs/metalearning/lib/python3.8/site-packages/torch/serialization.py", line 853, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'DagDataPreparation' on <module '__main__' from '/Users/brando/ML4Coq/ml4coq-proj/data_lib/dag/dag_dataloader.py'>

但是该类基本上是空的(尽管它位于不同的文件中DagDataPreparation

class DagDataPreparation:

    def __init__(self, root):
        print('1')

    def create_everything(self):
        return 1

如果我dill用作 torch.load 和 torch.save 的参数,它可以工作。如果我只是在文件顶部导入类,它也可以工作

from data_lib.dag.dataset_preparation import DagDataPreparation

我明白为什么会发生错误,它找不到类的定义......但这真的很奇怪,因为我 200% 我已经做过这种事情(使用火炬保存/加载任意类)和我以前从来没有遇到过这个问题。

可能是什么问题?如果您有两个文件一个加载pickle文件并且它向您抛出该错误,那么其他人是否在我的场景中遇到相同的错误?

标签: pytorch

解决方案


好吧,如果它在过去有效,那是因为在加载腌制对象时,您的模型定义可以从当前范围以某种方式访问​​,否则它将崩溃。这不是一个错误,而是泡菜的工作原理。

我想你已经看到了关于那个问题的 pytorch 建议,它不鼓励保存整个模型,而不是只保存你刚才提到的同一问题的权重,但以防万一,这里有链接

另请参阅此答案答案,它告诉 pickle 通过保存对类的引用来工作,而 dill 可以保存类定义,因此可以在加载时访问它。


推荐阅读