pytorch - 为优化器加载状态字典时的 Pytorch / 设备问题(cpu,gpu)
问题描述
嗨,我是从去年夏天开始学习 pytorch 的学生。
state = torch.load('drive/My Drive/MODEL/4 CBAM classifier55')
model = MyResNet()
model.load_state_dict(state['state_dict'])
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0003,betas=(0.5,0.999))
optimizer.load_state_dict(state['optimizer'])
model.to(device)
我写了上面的代码。
RuntimeError Traceback (most recent call last)
<ipython-input-26-507493db387a> in <module>()
56 new_loss.backward()
57
---> 58 optimizer.step()
59
60 running_loss += loss.item()
/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
13 def decorate_context(*args, **kwargs):
14 with self:
---> 15 return func(*args, **kwargs)
16 return decorate_context
17
/usr/local/lib/python3.6/dist-packages/torch/optim/adam.py in step(self, closure)
97
98 # Decay the first and second moment running average coefficient
---> 99 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
100 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
101 if amsgrad:
RuntimeError: expected device cpu but got device cuda:0
当我实现训练代码时,我得到了这种错误。当我注释掉“optimizer.load_state_dict”时,它运行良好。我怎么解决这个问题?谢谢您的回答。:)
解决方案
似乎state
在cuda
您保存时已打开,现在尝试使用它,cpu
反之亦然。为避免此错误,一种简单的方法是将map_location
参数传递给 load。
只需通过map_location=<device you want to use>
,torch.load
它应该可以正常工作。另外,请参阅https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
推荐阅读
- oauth - Google 验证 ID 令牌的完整性得到错误:IllegalArgumentException
- android - Kotlin 与 Dagger 2
- r - 计算存在缺席矩阵中多列与单列匹配的位置
- javascript - Is there a way to create an object property from within a method of the same object?
- java - Java Mail 找不到类 MimeMessage
- outlook - 如何使用管理员模拟通过 EWS 正确创建/更新公用文件夹?
- python - Tensorflow LSTMCell 中的 cell_clip 和 proj_clip 参数
- php - Woocommerce Storefront Galleria 添加类别名称,使其永久显示在类别缩略图上
- firebase - Firebase 推送通知交付收据
- xamarin.forms - 如何在 Xamarin 表单中进行拖放