首页 > 解决方案 > 如果条件输出发生变化,则使用 @tf.function 进行装饰

问题描述

我正在尝试评估我的变量 a 是否为空(即 has size == 0)。但是,当用 装饰代码时@tf.function,if 语句错误地计算为 True,而当删除装饰器时,它的计算结果为 False。tf.size(a)似乎在这两种情况下都正确评估为 0。如何解决这个问题?

import tensorflow as tf
a=tf.Variable([[]])
@tf.function
def test(a):
    print_op = tf.print(tf.size(a))
    print(tf.size(a))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None
test(a)

标签: pythontensorflowtensorflow2.0

解决方案


这有点让人头疼,但是,一旦我们了解tf.function将 python 操作和控制流映射到 tf 图,而裸函数只是急切地执行,我们就可以挑选它,这更有意义。

我已经调整了你的例子来说明发生了什么。考虑test1以下test2

@tf.function
def test1(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

def test2(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

@tf.function除了装饰器之外,它们彼此相同。

现在执行test2(tf.Variable([[]]))给了我们:

0
python print size: 0

这是我认为您期望的行为。而test1(tf.Variable([[]]))给出:

python print size: Tensor("Size_1:0", shape=(), dtype=int32)
fail
0

fail关于这个输出,有几件事(除了)你可能会感到惊讶:

  • print()语句打印出一个(尚未评估的)张量而不是零
  • print()和的顺序tf.print()颠倒了

这是因为通过添加@tf.function我们不再有 python 函数,而是使用 autograph 从函数代码映射的 tf 图。这意味着,在if评估条件时,我们还没有执行tf.math.not_equal(tf.size(a),0),只有一个对象(对象的实例Tensor),在 python 中,它是真实的:

class MyClass:
  pass
my_obj = MyClass()
if (my_obj):
  print ("my_obj evaluates to true") ## outputs "my_obj evaluates to true"

这意味着我们在评估之前得到了print('fail')语句。test1tf.math.not_equal(tf.size(a),0)

那么解决方法是什么?

好吧,如果我们print()在块中删除对仅 python 函数的调用if并将其替换为对签名友好的tf.print()语句,那么签名将无缝地将我们的if ... else ...逻辑转换为对图形友好的tf.cond语句,以确保一切都以正确的顺序发生:

定义测试3(a):
    print_op = tf.print(tf.size(a))
    print("python 打印尺寸:{}".format(tf.size(a)))
    如果 tf.math.not_equal(tf.size(a),0):
        tf.print('失败')
    使用 tf.control_dependencies([print_op]):
        返回无
test3(tf.Variable([[]]))
0
python print size: 0

推荐阅读