首页 > 解决方案 > 我对张量流回溯的理解是否正确?

问题描述

tensorflow doc给出了这个例子

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

这是输出

Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

我想我理解train(num_steps)定义为Function对象的第一部分。train(num_steps=10)train(num_steps=20)使用不同的 Python 值并导致回溯。

在第二部分中,train(num_steps=tf.constant(10))train(num_steps=tf.constant(20))导致重用 Traces whereprint("Tracing with num_steps = ", num_steps)正常运行,而参数部分num_steps=tf.constant(10)被 tf.Graph 捕获并且不在跟踪阶段运行。

我的理解正确吗?

标签: tensorflow

解决方案


tf.function用告诉装饰你的函数tensorflow来生成一个计算图。默认情况下,当您传递具有新形状的张量时会生成一个新图形

这会生成新图,因为它是第一次调用:

my_tf_func(tf.constant(10)) 

这不会生成新图,因为参数的形状与图的输入形状相同:

my_tf_func(tf.constant(20)) 

由于形状不同,这会生成新图形:

my_tf_func(tf.constant([10, 20])) 

当传递 python 值(不是 tf.tensors)时,每次使用不同的值进行调用时都会生成一个新图。所有这些调用都会生成新图:

my_tf_func(10) 
my_tf_func(20) 
my_tf_func(30) 

推荐阅读