python - 如何加载 .pth 文件?
问题描述
我从 [this repository][1] 获得了一个 pytorch 模型,我必须将其转换为 tflite。这是代码:
def get_torch_model(model_path):
"""
Loads state-dict into model and creates an instance.
"""
model= torch.load(model_path)
return model
# Conversion
import torch
from torchvision import transforms
import onnx
import cv2
import numpy as np
import onnx
import tensorflow as tf
import torch
from PIL import Image
import torch.onnx
image, tf_lite_image, sample_input = get_sample_input("crop.jpg")
torch_model = get_torch_model("pose_resnet_152_256x256.pth")
ONNX_FILE = "./m_model.onnx"
到这里为止,一切都很顺利。但是当我运行下面的单元格时:
torch.onnx.export(
model=torch_model,
args=sample_input,
f=ONNX_FILE,
verbose=False,
export_params=True,
do_constant_folding=False, # fold constant values for optimization
input_names=['input'],
opset_version=10,
output_names=['output']
)
onnx_model = onnx.load(ONNX_FILE)
onnx.checker.check_model(onnx_model)
完整的错误日志:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-33-15df717ec276> in <module>
8 input_names=['input'],
9 opset_version=10,
---> 10 output_names=['output']
11 )
12
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
274 do_constant_folding, example_outputs,
275 strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276 custom_opsets, enable_onnx_checker, use_external_data_format)
277
278
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
92 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
93 custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94 use_external_data_format=use_external_data_format)
95
96
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
677 _set_opset_version(opset_version)
678 _set_operator_export_type(operator_export_type)
--> 679 with select_model_mode_for_export(model, training):
680 val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
681 operator_export_type,
~\anaconda3\envs\py36\lib\contextlib.py in __enter__(self)
79 def __enter__(self):
80 try:
---> 81 return next(self.gen)
82 except StopIteration:
83 raise RuntimeError("generator didn't yield") from None
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in select_model_mode_for_export(model, mode)
36 def select_model_mode_for_export(model, mode):
37 if not isinstance(model, torch.jit.ScriptFunction):
---> 38 is_originally_training = model.training
39
40 if mode is None:
AttributeError: 'collections.OrderedDict' object has no attribute 'training'
当我使用 torch.onnx.export() 时会发生此错误。
请让我知道这里出了什么问题。我没有正确加载重量吗?如果没有,那么我如何加载模型?我不知道类或架构细节,所以我该如何使用 model.load_state_dict() ?
[1]: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
解决方案
pytorch 中的.pth
二进制文件不存储模型,而只存储经过训练的权重。您需要实现模型功能import
的class
(派生class
的torch.nn.Module
)。一旦你有了这个功能,你就可以加载经过训练的权重来获取模型的特定实例来使用。
推荐阅读
- django - Django REST Serializer 使用错误的模型进行序列化
- python - python文件写入程序运行时如何更新桌面上的文件大小
- javascript - 使用 d3.js 更新表数据
- c# - C#捕获从不在进程中的函数返回的异常?
- r - 如何设置仅在输入 3 时才显示集合向量的函数?
- javascript - 如果 URI 没有改变,例如在单页应用程序上,如何检测用户是否在新页面上?
- angular - Angular Kendo UI 全局访问
- php - 内连接循环通过
- git - 如何 git rebase 从另一个分支直接到 master 分支?
- javascript - 受控数字比例映射