How to load a .pth file?

Problem description

I obtained a PyTorch model from [this repository][1], and I need to convert it to tflite. Here is the code:

def get_torch_model(model_path):
    """
    Loads state-dict into model and creates an instance.
    """
    model= torch.load(model_path)
    return model
# Conversion
import torch
from torchvision import transforms

import onnx

import cv2
import numpy as np
import onnx
import tensorflow as tf
import torch
from PIL import Image

import torch.onnx

image, tf_lite_image, sample_input = get_sample_input("crop.jpg")
torch_model = get_torch_model("pose_resnet_152_256x256.pth")

ONNX_FILE = "./m_model.onnx"

Up to this point, everything works fine. But when I run the following cell:

torch.onnx.export(
        model=torch_model,
        args=sample_input,
        f=ONNX_FILE,
        verbose=False,
        export_params=True,
        do_constant_folding=False,  # whether to fold constant values for optimization (disabled here)
        input_names=['input'],
        opset_version=10,
        output_names=['output']
)

onnx_model = onnx.load(ONNX_FILE)

onnx.checker.check_model(onnx_model)

Full error log:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-33-15df717ec276> in <module>
      8         input_names=['input'],
      9         opset_version=10,
---> 10         output_names=['output']
     11 )
     12 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    274                         do_constant_folding, example_outputs,
    275                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276                         custom_opsets, enable_onnx_checker, use_external_data_format)
    277 
    278 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     92             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
     93             custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94             use_external_data_format=use_external_data_format)
     95 
     96 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
    677         _set_opset_version(opset_version)
    678         _set_operator_export_type(operator_export_type)
--> 679         with select_model_mode_for_export(model, training):
    680             val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
    681                                                              operator_export_type,

~\anaconda3\envs\py36\lib\contextlib.py in __enter__(self)
     79     def __enter__(self):
     80         try:
---> 81             return next(self.gen)
     82         except StopIteration:
     83             raise RuntimeError("generator didn't yield") from None

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in select_model_mode_for_export(model, mode)
     36 def select_model_mode_for_export(model, mode):
     37     if not isinstance(model, torch.jit.ScriptFunction):
---> 38         is_originally_training = model.training
     39 
     40         if mode is None:

AttributeError: 'collections.OrderedDict' object has no attribute 'training'

This error occurs when I call torch.onnx.export().

Please let me know what is going wrong here. Am I not loading the weights correctly? If not, how should I load the model? I don't know the class or the architecture details, so how can I use model.load_state_dict()?




  [1]: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch

Tags: python, pytorch, onnx

Solution


A .pth binary file in PyTorch does not store the model itself, only the trained weights (a state dict). That is why torch.load() returns a collections.OrderedDict here rather than a module, and the ONNX exporter fails on it. You need to import the model class (a class derived from torch.nn.Module) that defines the architecture; the repository you linked contains those definitions. Once you have that class, you can instantiate it and load the trained weights into the instance to get a usable model.
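A minimal sketch of that pattern follows. The PoseNet class below is only a placeholder; in practice you would import the actual network definition from the linked repository, so that its layer names match the keys stored in the .pth file:

import torch
import torch.nn as nn

# Placeholder architecture: in practice, import the real class from the
# linked repository so the layer names match the keys in the .pth file.
class PoseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        return self.backbone(x)

def get_torch_model(model_path):
    """Instantiate the architecture, then load the trained weights into it."""
    model = PoseNet()                                         # create an instance of the model class
    state_dict = torch.load(model_path, map_location="cpu")   # the .pth holds an OrderedDict of weights
    model.load_state_dict(state_dict)                         # copy the weights into the instance
    model.eval()                                              # ONNX export expects inference mode
    return model

Note that some published checkpoints wrap the weights under a key such as 'state_dict', or prefix every key with 'module.' when the model was trained with nn.DataParallel; if load_state_dict() reports missing or unexpected keys, inspect the loaded dictionary and adjust it accordingly.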

