首页 > 解决方案 > 哪种 tensorflow 方法确实决定特定批次的示例供模型学习?

问题描述

我正在尝试了解 SGD 在 tensorflow 中的实现。

由于文件名,我从gradient_descent.py开始。

根据keras doc,优化器需要实现_resource_apply_dense方法,该方法对应于下面显示的代码(部分)

def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    if self._momentum:
    momentum_var = self.get_slot(var, "momentum")
    return gen_training_ops.ResourceApplyKerasMomentum(
        ...

我想知道谁将var变量传递给_resource_apply_dense方法?换句话说,哪种方法决定这批特定的示例是供模型学习的?

标签: tensorflow

解决方案


检查optimizer_v2或 tensorflow keras,我们发现在整个 tensorflow 代码库中唯一使用了这个函数:

   #...
   def apply_grad_to_update_var(var, grad):
      #...
      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))

我们稍后在同一个文件中看到该var变量来自_distributed_apply函数的参数:

#...
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    #...
    with name_scope_only_in_function_or_graph(name or self._name):
      for grad, var in grads_and_vars:
      #...

最后,参数在函数grads_and_vars中定义为:List of (gradient, variable) pairsapply_gradients

  #...
  def apply_gradients(self,
                      grads_and_vars,
    #...
    """...
    Args:
      grads_and_vars: List of (gradient, variable) pairs.
    """

如果您检查apply_gradients( this search ) 的出现,您会发现这是更新网络权重的常用方法,因此受优化器的“更新”步骤控制。


推荐阅读