首页 > 解决方案 > 尝试理解 AutoGraph 和 tf.function: print loss in tf.function

问题描述

def train_one_step():
    with tf.GradientTape() as tape:
        a = tf.random.normal([1, 3, 1])
        b = tf.random.normal([1, 3, 1])
        loss = mse(a, b)

    tf.print('inner tf print', loss)
    print("inner py print", loss)

    return loss


@tf.function
def train():
    loss = train_one_step()

    tf.print('outer tf print', loss)
    print('outer py print', loss)

    return loss

loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)

我试图更多地了解 tf.functional 。我用不同的方法在四个地方打印了损失。它会产生这样的结果

inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)
  1. tf.print 和 python 打印有什么区别?
  2. 看起来 python print 将在图形评估期间执行,但 tf print 仅在执行时执行?
  3. 以上仅在有 tf.function 装饰器时适用?除此之外,tf.print 在 python print 之前运行?

标签: tensorflowtensorflow2.0

解决方案


我在一篇由三部分组成的文章中涵盖并回答了您的所有问题:“分析 tf.function 以发现 AutoGraph 的优势和微妙之处”:第 1部分、第 2部分、第 3 部分

总结并回答您的 3 个问题:

  • tf.print 和 python 打印有什么区别?

tf.print是一个 Tensorflow 结构,默认打印标准错误,更重要的是,它在评估时会产生一个操作。

当一个操作运行时,在急切执行中,它或多或少地以与 Tensorflow 1.x 相同的方式产生一个“节点”。

tf.function能够捕获生成的操作tf.print并将其转换为图形节点。

相反,print是一个 Python 构造,默认打印在标准输出上,执行时不生成任何操作。因此,tf.function无法将其转换为等效图形,只能在函数跟踪期间执行。

  • 看起来 python print 将在图形评估期间执行,但 tf print 仅在执行时执行?

我在前一点已经回答了这个问题,但再一次,print仅在函数跟踪期间tf.print执行,而在跟踪期间和执行其图形表示时都执行(在tf.function成功将函数转换为图形之后)。

  • 以上仅在有 tf.function 装饰器时适用?除此之外,tf.print 在 python print 之前运行?

是的。tf.print之前或之后不运行print。在 Eager Execution 中,一旦 Python 解释器找到语句,它们就会被评估。急切执行的唯一区别是输出流。

无论如何,我建议您阅读链接的三篇文章,因为它们详细介绍了tf.function.


推荐阅读