pytorch - Pytorch:如何将模型动物园预训练模型映射到新 GPU
问题描述
我正在尝试加载其中一个预训练模型
model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'}
当我使用以下代码时,它总是将模型加载到 cuda:0。如果我想将它加载到 cuda:3 怎么办?
model = ResNet(BasicBlock, [3, 4, 6, 3])
device = 3
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'],
map_location=lambda storage, loc: storage.cuda(device)))
解决方案
这应该为您完成工作:
device = torch.device('cuda')
model = ResNet(BasicBlock, [3, 4, 6, 3])
with torch.cuda.device(3):
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'],
map_location=lambda storage, loc: storage.cuda(device)))
我认为这适用于 0.4.0 及更高版本,您可以查看 0.4.0 中的更多示例。迁移指南: https ://pytorch.org/2018/04/22/0_4_0-migration-guide.html
推荐阅读
- python - 多处理管理器在 pool.apply_async 的非常简单的示例中失败
- python - Django 使用 .value 过滤数据库并调用外键
- go - 数据竞赛无法理解
- c - C 代码的 x86 反汇编生成:orq $0x0, %(rsp)
- apache-spark - 将 pyspark 列转换为列表
- r - 如何引用 R 中的所有其他列?
- python-3.x - 如果出现超过 1 次,如何将子标签移动到母标签之后?
- python - ValueError 视图没有返回 HttpResponse 对象。它返回 None 而不是
- javascript - 为什么 window.setTimeout 返回错误?
- reactjs - react-native undefined 不是对象