首页 > 解决方案 > 如何让工作机器在本地使用 RRefs,同时在 PyTorch RPC 中保留梯度

问题描述

假设我有一个 PyTorch 模块,其组件是RRef参数服务器中保存的模型的一个组件,例如以下示例:

class MyRPCModule(nn.Module):
    def __init__(self, ps, x_dim, h_dim, z_dim):
        super(MyRPCModule, self).__init__()

        # Ref for remote embedder on param server
        self.remote_emb_rref = rpc.remote(
            ps, Embedder, args=(x_dim, h_dim)
        )

Embedder 模块本身非常小,但我有大量的训练数据,我想将它们分散到多个不同的进程/机器上。我的forward方法需要是什么样子才能使该Embedder模块上的计算发生在机器调用上forward?我试过了

def forward(self, x):
    return self.remote_emb_rref.to_here()(x)

但这不会保留渐变,因此所有的backwardandstep调用都不会更改参数服务器上的参数。

我主要在这两个教程中工作:PyTorch 的“Dist RPC 框架入门”和他们的Combining DDP with RPC Framework但都使用参数服务器进行计算。有什么办法解决这个问题吗?

标签: pytorchdistributed-computingrpc

解决方案


推荐阅读