python-3.x - pytorch 模型加载和预测,AttributeError: 'dict' object has no attribute 'predict'
问题描述
model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
results, labels = predict_function(model, dev_data, version)
> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
-> phrase, spans, scores = model.predict(batch)
(Pdb) n
AttributeError: 'dict' object has no attribute 'predict'
如何加载已保存的 pytorch 模型检查点,并将其用于预测。我将模型保存在 .pt 扩展名中
解决方案
您保存的检查点通常是一个state_dict
:包含训练权重值的字典 - 但不是网络的实际架构。网络的实际计算图/架构被描述为一个 python 类(派生自nn.Module
)。
要使用经过训练的模型,您需要:
model
从实现计算图的类中 实例化 a 。加载保存
state_dict
到该实例:model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
推荐阅读
- android - 有没有办法检查我当前连接的 wifi 是否与以前连接的 wifi 不同?
- sh - 如何使用 sh 脚本对 crontab 作业进行哈希处理
- angular - Cannot properly set RequestOptionsArgs of POST, in service
- google-cloud-platform - 将近线存储转换为区域存储然后定期返回的成本是多少?
- mysql - MySQL查询列出不在另一个表中的用户
- caching - Drupal 8, Browser language detection fails after some time (anonymous users)
- jquery - How to automatically resize height of visjs Network component when window resized?
- rational-team-concert - 如何在 RTC 中“git stash”?
- wix - WIX msi with large number of files takes before welcome dialog is published
- dart - How can I extract a zip file archive in dart asynchronously?