首页 > 解决方案 > pytorch 模型加载和预测,AttributeError: 'dict' object has no attribute 'predict'

问题描述

model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
results, labels = predict_function(model, dev_data, version)

> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
-> phrase, spans, scores = model.predict(batch)
(Pdb) n
AttributeError: 'dict' object has no attribute 'predict'

如何加载已保存的 pytorch 模型检查点,并将其用于预测。我将模型保存在 .pt 扩展名中

标签: python-3.xmachine-learningpytorch

解决方案


您保存的检查点通常是一个state_dict:包含训练权重值的字典 - 但不是网络的实际架构。网络的实际计算图/架构被描述为一个 python 类(派生自nn.Module)。
要使用经过训练的模型,您需要:

  1. model从实现计算图的类中 实例化 a 。
  2. 加载保存state_dict到该实例:

    model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
    

推荐阅读