tensorflow - 为什么`tf.train.Optimizer().compute_gradients(loss)`也返回不在`loss`子图中的变量?
问题描述
我正在手动收集多任务模型的梯度统计信息,其图表如下所示:
input -> [body_var1 ... body_varN] --> [task1_var1 ... task1_varM] <-- loss_1
\-> [task2_var1 ... task2_varM] <-- loss_2
我为每个损失定义了一个单独的优化器,如下所示(实际代码非常复杂,以下针对这个问题进行了简化):
# for simplicity, just demonstrate the case with the 1st task
task_index = 1
# here we define the optimizer (create an instance in graph)
loss = losses[task_index]
optimizer = tf.train.GradientDescentOptimizer()
grads_and_vars = optimizer.compute_gradients(loss)
# now let's see what it returns
for g, v in grads_and_vars:
print(' grad:', g, ', var:', v)
因此,上面的代码清楚地为任务 1 的分支创建了一个单独的优化器,然后我们创建梯度计算操作optimizer.compute_gradients(loss)
并打印我们应用梯度的变量。
预期结果:
grad: body_var1_grad, var: body_var1 # \
... # --> body vars and gradients
grad: body_varN_grad, var: body_varN # /
grad: task1_var1_grad, var: task1_var1 # \
... # --> task 1 vars and gradients
grad: task1_var1_grad, var: task1_var1 # /
所以我期望优化器只包含它所应用的分支的梯度计算操作(即第一个任务的分支)
实际结果
grad: body_var1_grad, var: body_var1 # \
... # --> body vars and gradients
grad: body_varN_grad, var: body_varN # /
grad: task1_var1_grad, var: task1_var1 # \
... # --> task 1 vars and gradients
grad: task1_var1_grad, var: task1_var1 # /
grad: None, var: task2_var1 # \
... # --> task 2 vars, with None gradients
grad: None, var: task2_var1 # /
因此,它看起来optimizer.compute_gradients(loss)
不仅捕获了输出到的子图loss
(可以使用 提取tf.graph_util.extract_sub_graph
),而且还捕获了所有连接到的可训练变量,而loss
无需为它们创建梯度变量(因此返回的梯度变量为None
)。
问:这样的行为正常吗?
解决方案
是的,因为compute_gradients()计算相对于传递给参数的对象loss
列表的梯度。如果未提供,则该函数计算关于GraphKeys.TRAINABLE_VARIABLES集合中所有变量的梯度。此外,如果不依赖于某些变量,则未定义相对于这些变量的梯度,即返回。根据您提供的代码,情况似乎如此。tf.Variable
var_list
var_list
loss
loss
None
如果您只想optimizer
计算某些变量的梯度,您应该列出这些变量并将其传递给 的var_list
参数compute_gradients()
。
推荐阅读
- mysql - 查询获取上周添加的产品详情
- angular - 在响应时删除 Angular Material 分页
- jquery - 如果鼠标快速移动,则 Jquery 悬停问题
- python - 如何从全局类范围访问 self.function
- android - DialogFragment.dismiss() leaves keyboard on screen
- java - 如何在自定义 getter 中序列化 JSON 对象?
- javascript - Mobile click on body does not fire
- sas - ODBC 和 Netezza 引擎之间的区别
- r - position_dodge on geom_text on dodged barplot
- javascript - Custom JSX to React intl formattedMessage