首页 > 解决方案 > Pytorch 1.0:net.to(device) 在 nn.DataParallel 中做了什么?

问题描述

pytorch data paraleelism教程中的以下代码对我来说很奇怪:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

据我所知,mode.to(device)将数据复制到 GPU。

DataParallel 自动拆分您的数据并将工作订单发送到多个 GPU 上的多个模型。在每个模型完成其工作后,DataParallel 会收集并合并结果,然后再将其返回给您。

如果DataParallel做复制的工作,to(device)这里做什么?

标签: deep-learningpytorch

解决方案


他们在教程中添加了几行来解释nn.DataParallel

DataParallel 自动拆分您的数据,并使用数据将作业订单发送到不同 GPU 上的多个模型。在每个模型完成其工作后,DataParallel 会为您收集并合并结果。

上面的引用可以理解nn.DataParallel为只是一个包装类,通知model.cuda()应该多拷贝到GPU。

就我而言,我的笔记本电脑上没有任何 GPU。我仍然打电话nn.DataParallel()没有任何问题。

import torch
import torchvision

model = torchvision.models.alexnet()
model = torch.nn.DataParallel(model)
# No error appears if I don't move the model to `cuda`

推荐阅读