load - 当我使用 torch.load 时出现运行时错误“存储大小错误:”
问题描述
当我调用torch.load("pthfilename")
. 我的模型在多个 GPU 上进行了训练,我使用以下代码保存了模型:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
device = torch.device(arg.local_rank)
net = Net().to(device)
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[arg.local_rank])
torch.save(net.state_dict(), "0.pth"))
错误是:
Traceback (most recent call last):
File "/root/PycharmProjects/test.py", line 8, in <module>
model_dict = torch.load("0.pth")
File "torch/serialization.py", line 529, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "torch/serialization.py", line 709, in _legacy_load
deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
RuntimeError: storage has wrong size: expected -4916312287391674656 got 24
解决方案
如果您使用多进程类型(例如:DistributedDataParallel)训练您的模型,您应该在保存模型时分配一个 local_rank。
def save_checkpoint(epoch, model, best_top5, optimizer,
is_best=False,
filename='checkpoint.pth.tar'):
state = {
'epoch': epoch+1, 'state_dict': model.state_dict(),
'best_top5': best_top5, 'optimizer' : optimizer.state_dict(),
}
torch.save(state, filename)
if args.local_rank == 0:
if is_best: save_checkpoint(epoch, model, best_top5, optimizer, is_best=True, filename='model_best.pth.tar')
推荐阅读
- python - 如何从python中的json读取字节?
- html - 使用 SVG 编码将鼠标悬停在 PNG 图像上时如何获得不同的颜色/填充/不透明度
- python - 轮廓的曲面图
- java - Java 在 x 时间段内调用 REST 服务,以定期间隔进行响应
- react-native - 试图在本机反应中将图像数据传递到另一个屏幕
- docusignapi - 调整 Docusign 的标志此处选项卡的大小
- ios - applicationDidEnterBackground 中的应用程序时未调用 applicationWillTerminate
- docker - Docker swarm 路由重定向不起作用
- excel - 宏需要很长时间才能循环近 700 行
- php - 附件文件未显示在电子邮件收件箱 php 邮件中