python - 在 PyTorch 中加载迁移学习模型进行推理的正确方法是什么?
问题描述
我正在使用基于 Resnet152 的迁移学习来训练模型。基于 PyTorch 教程,我在保存训练模型并加载它进行推理方面没有问题。但是,加载模型所需的时间很慢。我不知道我是否正确,这是我的代码:
要将训练好的模型保存为状态字典:
torch.save(model.state_dict(), 'model.pkl')
加载它进行推理:
model = models.resnet152()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(classes))
st = torch.load('model.pkl', map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(st)
model.eval()
我对代码进行了计时,发现第一行model = models.resnet152()
加载时间最长。在 CPU 上,测试一张图像需要 10 秒。所以我的想法是这可能不是加载它的正确方法?
如果我像这样保存整个模型而不是 state.dict:
torch.save(model, 'model_entire.pkl')
并像这样测试它:
model = torch.load('model_entire.pkl')
model.eval()
在同一台机器上测试一张图像只需 5 秒。
所以我的问题是:这是加载 state_dict 进行推理的正确方法吗?
解决方案
在第一个代码片段中,您正在从 TorchVision 下载一个模型(具有随机权重),然后将您的(本地存储的)权重加载到它。
在第二个示例中,您正在加载本地存储的模型(及其权重)。
前者会更慢,因为您需要连接到托管模型的服务器并下载它,而不是本地文件,但它更可复制而不依赖于您的本地文件。此外,时间差应该是一次性初始化,并且它们应该具有相同的时间复杂度(在您执行推理时,模型已经在两者中加载,并且它们是等效的)。
推荐阅读
- wordpress - 在 Azure 应用程序网关后面访问 Wordpress
- java - 如何使用 Netbeans GUI 编辑器检查单击了哪个按钮
- javascript - jQuery Ajax 不在 Safari 上发布
- javascript - textarea 的 .resize() 函数
- c# - C# 泛型和派生
- powerbi - Dax Studio 中的简单 LOOKUPVALUE 错误
- neural-network - lstm中最大池化对情感分析的意义
- javascript - 方法销毁数据表
- pygame - Pygame:模拟触发初始值不是中性触发位置
- html - 有没有办法使用 ES 模块获得类似 iframe 的功能