tensorflow - 哪种 tensorflow 方法确实决定特定批次的示例供模型学习?
问题描述
我正在尝试了解 SGD 在 tensorflow 中的实现。
由于文件名,我从gradient_descent.py开始。
根据keras doc,优化器需要实现_resource_apply_dense
方法,该方法对应于下面显示的代码(部分):
def _resource_apply_dense(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))
if self._momentum:
momentum_var = self.get_slot(var, "momentum")
return gen_training_ops.ResourceApplyKerasMomentum(
...
我想知道谁将var
变量传递给_resource_apply_dense
方法?换句话说,哪种方法决定这批特定的示例是供模型学习的?
解决方案
检查optimizer_v2或 tensorflow keras,我们发现在整个 tensorflow 代码库中唯一使用了这个函数:
#...
def apply_grad_to_update_var(var, grad):
#...
if "apply_state" in self._dense_apply_args:
apply_kwargs["apply_state"] = apply_state
update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
if var.constraint is not None:
with ops.control_dependencies([update_op]):
return var.assign(var.constraint(var))
我们稍后在同一个文件中看到该var
变量来自_distributed_apply
函数的参数:
#...
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
#...
with name_scope_only_in_function_or_graph(name or self._name):
for grad, var in grads_and_vars:
#...
最后,参数在函数grads_and_vars
中定义为:List of (gradient, variable) pairs
apply_gradients
#...
def apply_gradients(self,
grads_and_vars,
#...
"""...
Args:
grads_and_vars: List of (gradient, variable) pairs.
"""
如果您检查apply_gradients
( this search ) 的出现,您会发现这是更新网络权重的常用方法,因此受优化器的“更新”步骤控制。
推荐阅读
- ios - 使用 Swift UI 更改文本字段时动画颜色变化
- c - 向客户提供带有未加密 h 文件的加密 c 代码
- mysql - 如何在 MySQL 中正确分组
- odoo - Odoo v14:通过付款创建的日记帐分录的名称前缀始终为 BNK1
- react-native - React Native)如何在地图函数中包装的元素之间添加边距?
- postgresql - Hasura 按日期排序
- graphql - 如何使用变量来选择 graphql 对象嵌套字段?
- python - 根据输入百分比拆分列表范围
- rating - 如果您在星级评分中使用 bootstrap 5 图标,则存在半星问题
- sftp - 通过jsch将zip文件下载到sftp站点时,输入流随机关闭错误