首页 > 解决方案 > 分布式 TensorFlow:谁应用参数更新?

问题描述

我用过TensorFlow但对分布式TensorFlow训练模型很陌生。我的理解是,当前的最佳实践有利于具有异步更新的数据并行模型:

Google Brain 团队在 2016 年 4 月发表的一篇论文对各种方法进行了基准测试,发现使用少量备用副本进行同步更新的数据并行性是最有效的,不仅收敛速度更快,而且可以生成更好的模型。-使用 Scikit-Learn 和 Tensorflow 进行机器学习的第 12 章 。

现在,我在进一步阅读此架构时感到困惑的是,要弄清楚哪个组件应用了参数更新:worker 还是参数服务器?

在我下面的插图中,我很清楚工人计算梯度dJ/dw(损失 J 相对于参数权重 w 的梯度)。但是谁应用梯度下降更新规则呢?

在此处输入图像描述

有点令人困惑的是,这篇关于分布式 TensorFlow 的 O'Reilly 文章陈述了以下内容:

在更集中的架构中,设备以梯度的形式将其输出发送到参数服务器。这些服务器收集并聚合梯度。在同步训练中,参数服务器计算模型的最新版本,并将其发送回设备。在异步训练中,参数服务器将梯度发送到本地计算新模型的设备。在这两种架构中,循环都会重复,直到训练终止。

上一段建议在异步训练中:

  1. 工作人员计算梯度并将其发送到参数服务器。
  2. 参数服务器将梯度广播给工作人员。
  3. 每个工作人员接收广播的梯度并应用更新规则。

我的理解正确吗?如果是,那么这对我来说似乎不是很异步,因为工作人员必须等待参数服务器广播渐变。任何解释将不胜感激。

标签: tensorflowmachine-learning

解决方案


我意识到这是在 2018 年提出的,但让我们试一试。

  1. 每个 Worker 计算梯度
  2. 当工作人员完成计算梯度时,它会将其发送到参数服务器。
  3. 然后,worker 从参数服务器收到新参数,而无需等待其他worker。

在同步部分,在每个工作人员将其更新发送到服务器之前,工作人员不会继续训练。

这在异步情况下意味着每个工作人员可以有稍微不同的梯度,因为他们正在获取梯度而不等待每个工作人员更新参数服务器。


推荐阅读