tensorflow - 尝试理解 AutoGraph 和 tf.function: print loss in tf.function
问题描述
def train_one_step():
with tf.GradientTape() as tape:
a = tf.random.normal([1, 3, 1])
b = tf.random.normal([1, 3, 1])
loss = mse(a, b)
tf.print('inner tf print', loss)
print("inner py print", loss)
return loss
@tf.function
def train():
loss = train_one_step()
tf.print('outer tf print', loss)
print('outer py print', loss)
return loss
loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)
我试图更多地了解 tf.functional 。我用不同的方法在四个地方打印了损失。它会产生这样的结果
inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)
- tf.print 和 python 打印有什么区别?
- 看起来 python print 将在图形评估期间执行,但 tf print 仅在执行时执行?
- 以上仅在有 tf.function 装饰器时适用?除此之外,tf.print 在 python print 之前运行?
解决方案
我在一篇由三部分组成的文章中涵盖并回答了您的所有问题:“分析 tf.function 以发现 AutoGraph 的优势和微妙之处”:第 1部分、第 2部分、第 3 部分。
总结并回答您的 3 个问题:
- tf.print 和 python 打印有什么区别?
tf.print
是一个 Tensorflow 结构,默认打印标准错误,更重要的是,它在评估时会产生一个操作。
当一个操作运行时,在急切执行中,它或多或少地以与 Tensorflow 1.x 相同的方式产生一个“节点”。
tf.function
能够捕获生成的操作tf.print
并将其转换为图形节点。
相反,print
是一个 Python 构造,默认打印在标准输出上,执行时不生成任何操作。因此,tf.function
无法将其转换为等效图形,只能在函数跟踪期间执行。
- 看起来 python print 将在图形评估期间执行,但 tf print 仅在执行时执行?
我在前一点已经回答了这个问题,但再一次,print
仅在函数跟踪期间tf.print
执行,而在跟踪期间和执行其图形表示时都执行(在tf.function
成功将函数转换为图形之后)。
- 以上仅在有 tf.function 装饰器时适用?除此之外,tf.print 在 python print 之前运行?
是的。tf.print
之前或之后不运行print
。在 Eager Execution 中,一旦 Python 解释器找到语句,它们就会被评估。急切执行的唯一区别是输出流。
无论如何,我建议您阅读链接的三篇文章,因为它们详细介绍了tf.function
.
推荐阅读
- c# - 没有为实体类型“FormFile”找到合适的构造函数
- angular7 - 为什么路由器插座中的滚动动画不起作用以及如何解决?
- python - 在 python 中使用谷歌地图 API 进行批量地理编码时出现超时错误
- python - 如何使用 for 循环在 python 中创建 Traingular 移动平均线
- java - 如何在服务器上获取客户端的IP?
- mysql - 如何在 Mysql 中使用子查询来解决这个挑战?
- javascript - 从多个输入中获取所有文件
- c# - 如何找到名称中数字较大的对象?
- python-2.7 - 如何使用 Python 从 hdf 格式的文件名中提取日期?
- c - 如何从 .txt 文件中读取并将每一行分成不同的字符串并将它们存储到不同的结构变量中?