首页 > 解决方案 > 使用 tf.function 时采用渐变

问题描述

我对以下示例中观察到的行为感到困惑:

import tensorflow as tf

@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c

def fplain(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c


a = tf.Variable([[0., 1.], [1., 0.]])

with tf.GradientTape() as tape:
    b, c = f(a)
    
print('tf.function gradient: ', tape.gradient([b], [c]))

# outputs: tf.function gradient:  [None]

with tf.GradientTape() as tape:
    b, c = fplain(a)
    
print('plain gradient: ', tape.gradient([b], [c]))

# outputs: plain gradient:  [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
#        [6., 2.]], dtype=float32)>]

较低的行为是我所期望的。我如何理解 @tf.function 案例?

非常感谢您!

(请注意,这个问题不同于:使用 tf.function 时缺少梯度,因为这里所有的计算都在函数内部。)

标签: pythontensorflow2.0decoratorgradienttape

解决方案


梯度磁带不记录将@tf.function函数视为一个整体而生成的 tf.Graph 内部的操作。粗略地说,f应用于a,梯度磁带记录了 的输出f相对于输入的梯度a(它是唯一观察到的变量,tape.watched_variables())。

在第二种情况下,没有生成图,并且在 Eager 模式下应用操作。所以一切都按预期工作。

一个好的做法是在@tf.function(通常是训练循环)中包装一个计算成本最高的函数。在你的情况下,它会像:

@tf.function
def f(a):
    with tf.GradientTape() as tape:
        c = a * 2
        b = tf.reduce_sum(c ** 2 + 2 * c)
    grads = tape.gradient([b], [c])
    print('tf.function gradient: ', grads)
    return grads

推荐阅读