pytorch - 平均后如何更新 PyTorch 模型参数(张量)?
问题描述
我目前正在研究分布式联邦学习基础设施,并正在尝试实现 PyTorch。为此,我还需要联合平均,它对从所有节点检索到的参数进行平均,然后将这些参数传递给下一轮训练。
参数的收集如下所示:
def RPC_get_parameters(data, model):
"""
Get parameters from nodes
"""
with torch.no_grad():
for parameters in model.parameters():
# store parameters in dict
return {"parameters": parameters}
在中央服务器上发生的平均函数如下所示:
# stores results from RPC_get_parameters() in results
results = client.get_results(task_id=task.get("id"))
# averaging of returned parameters
global_sum = 0
global_count = 0
for output in results:
global_sum += output["parameters"]
global_count += len(global_sum)
#
averaged_parameters = global_sum/global_count
#
new_params = {'averaged_parameters': averaged_parameters}
现在我的问题是,如何从这里更新 Pytorch 中的所有参数(张量)?我尝试了一些事情,当将 new_params 插入优化器时,它们通常返回诸如“值错误:无法优化非叶张量”之类的错误,其中通常 model.parameters() go optimizerD = optim.SGD(new_params, lr=0.01 , 动量 = 0.5)。那么我如何实际更新模型以使其使用平均参数呢?
谢谢!
https://github.com/simontkl/torch-vantage6/blob/fed_avg-w/local_dp/v6-ppsdg-py/master.py
解决方案
我认为使用参数(在 SGD 上下文之外)最方便的方法是使用state_dict
模型的。
new_params = OrderedDict()
n = len(clients) # number of clients
for client_model in clients:
sd = client_model.state_dict() # get current parameters of one client
for k, v in sd.items():
new_params[k] = new_params.get(k, 0) + v / n
之后new_params
是state_dict
(您可以使用加载它.load_state_dict
)具有客户端的平均权重。
推荐阅读
- android - 为什么图像的imageview被剪切?(在动画期间,imageview 在屏幕之外的一半)
- python - 在张量流中包裹卷积
- c++ - 从 std::string 到 std::vector 的快速转换
- ios - 为什么 didSelect 函数仅适用于视图控制器中的两个集合视图之一?
- python - 我是否能够将来自许多不同抓取网站的数据合并到一个 csv 文件中?
- php - 为什么变量仍然通过(名称验证)?*初学者*
- flutter - 如何在flutter中重新生成ios和android文件夹-我想重命名项目名称(com.example.projectname)
- log4cxx - 如何使用 LOG4CXX 库记录 Δ 字符
- javascript - 如何从具有模式的数组中生成随机数?
- javascript - 指数递归