python - 使用 tf.function 时采用渐变
问题描述
我对以下示例中观察到的行为感到困惑:
import tensorflow as tf
@tf.function
def f(a):
c = a * 2
b = tf.reduce_sum(c ** 2 + 2 * c)
return b, c
def fplain(a):
c = a * 2
b = tf.reduce_sum(c ** 2 + 2 * c)
return b, c
a = tf.Variable([[0., 1.], [1., 0.]])
with tf.GradientTape() as tape:
b, c = f(a)
print('tf.function gradient: ', tape.gradient([b], [c]))
# outputs: tf.function gradient: [None]
with tf.GradientTape() as tape:
b, c = fplain(a)
print('plain gradient: ', tape.gradient([b], [c]))
# outputs: plain gradient: [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
# [6., 2.]], dtype=float32)>]
较低的行为是我所期望的。我如何理解 @tf.function 案例?
非常感谢您!
(请注意,这个问题不同于:使用 tf.function 时缺少梯度,因为这里所有的计算都在函数内部。)
解决方案
梯度磁带不记录将@tf.function
函数视为一个整体而生成的 tf.Graph 内部的操作。粗略地说,f
应用于a
,梯度磁带记录了 的输出f
相对于输入的梯度a
(它是唯一观察到的变量,tape.watched_variables()
)。
在第二种情况下,没有生成图,并且在 Eager 模式下应用操作。所以一切都按预期工作。
一个好的做法是在@tf.function
(通常是训练循环)中包装一个计算成本最高的函数。在你的情况下,它会像:
@tf.function
def f(a):
with tf.GradientTape() as tape:
c = a * 2
b = tf.reduce_sum(c ** 2 + 2 * c)
grads = tape.gradient([b], [c])
print('tf.function gradient: ', grads)
return grads
推荐阅读
- apache-spark - 在 pyspark 中使用 when 语句 - 当我添加到脚本的各个部分时不起作用
- reactjs - 在 Apollo Mutation 中调用 Meteor 方法
- google-cloud-platform - 通过 Dataflow 管道写入 Cloud SQL 非常慢
- if-statement - rpm 有条件的子字符串?
- c++ - 如何在 Visual C++ 中使用仅标头库?
- javascript - 为什么 Jquery onclick 监听器不起作用?
- html - 图像宽度被其他东西覆盖
- spring-boot - Jackson Object Mapper 在提供扩展配置时不工作,但在 Spring Boot 中提供类级别/字段级别注释时工作
- ansible - 使用 Ansible 将启动磁盘(如果存在)附加到 Gcloud 实例
- python - Kivy 数字时钟问题