pytorch - 用 CPU 加载 pickle 保存的 GPU 张量?
问题描述
我使用 GPU 上的 pickle 将 Bert 的最后一个隐藏层保存为我的后续过程。
# output is the last hidden layer of bert, transformed on GPU
with open(filename, 'wb') as f:
pk.dump(output, f)
是否可以在没有 GPU 的个人笔记本电脑上加载它?我尝试了以下代码,但都失败了。
# 1st try
with open(filename, 'rb') as f:
torch.load(f, map_location='cpu')
# 2nd
torch.load(filename, map_location=torch.device('cpu'))
都得到以下错误
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
是否可以在我的笔记本电脑上加载文件?
解决方案
如果您使用 pytorch,您可以通过保存state_dict
模型而不是模型本身来省去一些麻烦。这state_dict
是一个存储神经网络权重的有序字典。
保存程序:
import torch
model = MyFabulousPytorchModel()
torch.save(model.state_dict(), "best_model.pt")
加载它需要您首先初始化模型:
import torch
device = 'cuda' if torch.cuda.is_available() else 'gpu'
model = MyFabulousPytorchModel()
model.load_state_dict(torch.load(PATH_TO_MODEL))
model.device(device)
state_dict
直接保存对象而不是对象有很多优点。其中之一与您的问题有关:将您的模型移植到不同的环境并不像您希望的那样轻松。另一个优点是保存检查点要容易得多,这些检查点可以让您恢复训练,就好像训练从未停止过一样。您所要做的就是保存优化器的状态和损失:
保存检查点:
# somewhere in your training loop:
opt.zero_grad()
pred = model(x)
loss = loss_func(pred, target)
torch.save({"model": model.state_dict(), "opt": opt.state_dict(), "loss":loss}, "checkpoing.pt")
我强烈建议您查看文档以获取有关如何使用 pytorch 保存和加载模型的更多信息。如果您了解其内部工作原理,这是一个非常顺利的过程。https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference
希望有帮助=)
编辑:
更直接,为了解决您的问题,我推荐以下
1-在您用来训练模型的计算机上:
import torch
model = torch.load("PATH_TO_MODEL")
torch.save(model.state_dict(), "PATH.pt")
2-在另一台计算机上:
import torch
from FILE_WHERE_THE_MODEL_CLASS_IS_DEFINED import Model
model = Model() # initialize one instance of the model)
model.load_state_dict(torch.load("PATH.pt")
推荐阅读
- reactjs - 使用 Oauth2、React、Node.js 和 Passport.js 通过 Google 登录按钮对用户进行身份验证的最佳做法是什么?
- ms-access - 表格字段为 4 位十进制,而原始数据仅为 2 位十进制
- azerothcore - 在 Visual Studio 中构建源代码时 .obj 文件中的错误
- mongodb - 在 kali linux 中安装 mongodb 时,我在 apt-get 中遇到错误
- php - 尽管我调用了补丁方法'@method('PATCH')',但修改按钮不起作用,我想显示文档但它没有出现
- javascript - javascript在循环后不起作用它在while和for中都是一样的
- matplotlib - 排序子图中的标签,使用 fig.add_axes 创建
- java - 如何处理 Mallet 中 cmd 行中的空格?
- android - 如何为 Android Studio 制作 SDK 位置
- php - 单击按钮时如何仅更新 1 行而不是所有行