首页 > 解决方案 > 平均后如何更新 PyTorch 模型参数(张量)?

问题描述

我目前正在研究分布式联邦学习基础设施,并正在尝试实现 PyTorch。为此,我还需要联合平均,它对从所有节点检索到的参数进行平均,然后将这些参数传递给下一轮训练。

参数的收集如下所示:

def RPC_get_parameters(data, model):
    """
    Get parameters from nodes
    """

    with torch.no_grad():
        for parameters in model.parameters():
            # store parameters in dict
            return {"parameters": parameters}

在中央服务器上发生的平均函数如下所示:

# stores results from RPC_get_parameters() in results

results = client.get_results(task_id=task.get("id"))

# averaging of returned parameters
global_sum = 0
global_count = 0
    
for output in results:
    global_sum += output["parameters"]
    global_count += len(global_sum)
    #
    averaged_parameters = global_sum/global_count
    #
    new_params = {'averaged_parameters': averaged_parameters}

现在我的问题是,如何从这里更新 Pytorch 中的所有参数(张量)?我尝试了一些事情,当将 new_params 插入优化器时,它们通常返回诸如“值错误:无法优化非叶张量”之类的错误,其中通常 model.parameters() go optimizerD = optim.SGD(new_params, lr=0.01 , 动量 = 0.5)。那么我如何实际更新模型以使其使用平均参数呢?

谢谢!

https://github.com/simontkl/torch-vantage6/blob/fed_avg-w/local_dp/v6-ppsdg-py/master.py

标签: pytorch

解决方案


我认为使用参数(在 SGD 上下文之外)最方便的方法是使用state_dict模型的。


new_params = OrderedDict()

n = len(clients)  # number of clients

for client_model in clients:
  sd = client_model.state_dict()  # get current parameters of one client
  for k, v in sd.items():
    new_params[k] = new_params.get(k, 0) + v / n

之后new_paramsstate_dict(您可以使用加载它.load_state_dict)具有客户端的平均权重。


推荐阅读