首页 > 解决方案 > 如果图表损坏,tf.Print 不起作用

问题描述

我正在尝试构建一个完全卷积的神经网络。我的问题是,在某个阶段,张量的形状不再匹配,导致异常,我想在每个步骤之后打印张量的形状,以便能够查明问题。然而问题是,如果图形被破坏并且在某个时候抛出异常(即使异常发生在管道中的 print 语句之后),tf.Print 似乎不会打印任何内容。我在打印中使用下面的代码。如果我有一个工作图,它工作正常。那么 tf.Print 真的只能用于工作图吗?如果是这种情况,我怎么能打印张量的形状,或者是唯一可能使用一些调试器,例如 tfdbg?

upsample = custom_layers.crop_center(input_layer, upsample)
upsample_print = tf.Print(upsample, [tf.shape(upsample)], "shape of tensor is ")
logits = tf.reshape(upsample_print, [-1, 2])
...

给出的错误是

ValueError: Dimension size must be evenly divisible by 2898844 but is 2005644 for 'gradients/Reshape_grad/Reshape' (op: 'Reshape') with input shapes: [1002822,2], [4] and with input tensors computed as partial shapes: input[1] = [?,1391,1042,2].

标签: pythondebuggingtensorflow

解决方案


tf.Print仅在运行时打印。它只是将一个节点添加到图形中,该节点在执行时会向控制台打印一些内容。因此,如果您的图表无法构建,即无法执行任何计算,您将永远不会看到tf.Print.

在构建时,您只能看到张量的静态形状(例如,使用 Python 原生打印语句打印它们)。我不知道在构建时获取动态形状的任何方法(动态形状取决于您提供的实际输入,因此在您实际提供某些东西之前无法知道这一点,这只发生在运行时)。对于我的目的而言,了解静态形状通常就足够了。如果您不是这种情况,请尝试在玩具示例中将动态尺寸设为静态,然后 Python 打印所有形状以追踪问题。


推荐阅读