首页 > 解决方案 > 使用 torch DistributedDataParallel 时是否可以有一个中心函数/对象?

问题描述

据我阅读,DistributedDataParallel并将DistributedSampler模型/数据加载器分别复制和拆分到每个进程,并且该进程只看到数据加载器的那一部分。该进程的调用就像进程 IDmain(rank, *args)在哪里一样。rank我有一件事想尝试,但我搜索了文档,但我找不到这是否可行。

我目前正在 ImageNet 上训练模型。我有一段代码创建一个对象,我希望它用作所有进程的中心对象,类似于DataParallel. 对于每批数据,我需要向中心对象发送一些信息,然后它返回一些操作以继续训练模型。但是,我不能为每个进程创建单独的对象,因为它会破坏并改变方法的性质。这是我想要它做的事情:

def main(rank, *args):
    model = DistributedDataParallel(...)
    dataloader = DataLoader(...)
    for inputs, labels in dataloader:
        operation = send_info_to_central_object(inputs, labels)
        inputs_aug, labels_aug = do_operation(inputs, labels, operation)

        # Note that to receive inputs_aug directly from central object is doable

简而言之,是否有可能具有将每个进程连接到torch中的单个/中心对象DistributedDataParallel的功能?

标签: pytorch

解决方案


推荐阅读