首页 > 解决方案 > 张量流中的条件打印节点

问题描述

我正在寻找一种在张量流中使用条件打印节点的方法,使用下面的示例代码行,其中每 10 个循环计数,它应该在控制台中打印一些东西。但这对我不起作用。任何人都可以建议吗?

谢谢,哈米德雷萨,

epsilon = tf.cond(tf.constant(counter % 10 == 0, dtype=tf.bool), true_fn=lambda:tf.Print(epsilon, [counter, epsilon], 'batch: ', summarize=10), false_fn=lambda:epsilon)

标签: tensorflowprintingconditional

解决方案


我遇到了类似的问题,并用以下丑陋的解决方案解决了它。

我正在使用tf.print而不是'tf.Print',因为它已被解密:

由于 true_func 和 false_func 应该返回相同的类型和形状,所以我返回一个无意义的 const。

def true_func():
    printfunc = tf.print(TENSOR_TO_PRINT,summarize=-1)
    with tf.control_dependencies([printfunc]):
        return tf.constant(1) 

def false_func():
    return tf.constant(1)

shouldprint = tf.cond(YOUR_CONDITION ,true_fn = true_func ,false_fn=false_func)

然后您可以运行该shouldprint操作或将其添加为其他操作的依赖项。


推荐阅读