pytorch - 如何在没有直接引用其类定义且没有莳萝的情况下使用 torch.load/save 腌制自定义对象?
问题描述
我有一个非常简单的场景。我有 1 个自定义类(不是 NN 模型)我正在酸洗torch.save
。当我尝试加载它时失败并出现错误:
Traceback (most recent call last):
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
exec(exp, global_vars, local_vars)
File "<input>", line 1, in <module>
File "/Users/brando/anaconda3/envs/metalearning/lib/python3.8/site-packages/torch/serialization.py", line 594, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "/Users/brando/anaconda3/envs/metalearning/lib/python3.8/site-packages/torch/serialization.py", line 853, in _load
result = unpickler.load()
AttributeError: Can't get attribute 'DagDataPreparation' on <module '__main__' from '/Users/brando/ML4Coq/ml4coq-proj/data_lib/dag/dag_dataloader.py'>
但是该类基本上是空的(尽管它位于不同的文件中DagDataPreparation
:
class DagDataPreparation:
def __init__(self, root):
print('1')
def create_everything(self):
return 1
如果我dill
用作 torch.load 和 torch.save 的参数,它可以工作。如果我只是在文件顶部导入类,它也可以工作
from data_lib.dag.dataset_preparation import DagDataPreparation
我明白为什么会发生错误,它找不到类的定义......但这真的很奇怪,因为我 200% 我已经做过这种事情(使用火炬保存/加载任意类)和我以前从来没有遇到过这个问题。
可能是什么问题?如果您有两个文件一个加载pickle文件并且它向您抛出该错误,那么其他人是否在我的场景中遇到相同的错误?