首页 > 解决方案 > 使用 tf.map_fn 计算张量的逆

问题描述

使用 Tensorflow 1.4 我想使用映射函数计算张量的元素逆(x --> 1/x)。如果张量中某个元素的值为零,我希望输出为零。

作为 的示例tensor: [[0, 1, 0], [0.5, 0.5, 0.3]],我想要output: [[0,1,0], [2, 2, 3.333]]. 我知道我可以使用 tf.math.reciprocal_no_nan() intf2.0tf.math.divide_no_nan()in轻松获得所需的输出tf 1.4,但我想知道为什么以下代码不起作用:

tensor = tf.constant([[0, 1, 0], [0.5, 0.5, 0.3]], tf.float32)
tensor_inverse = tf.map_fn(lambda x: tf.cond(tf.math.not_equal(x, 0.0), lambda x: 1/x, lambda: 0) , tensor)

我收到此错误:

ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()

标签: pythontensorflowlambdamap-functioninverse

解决方案


让我们分解您的代码示例:

您使用的第一个功能是map_fn. Map_fn将在第一个维度上拆分张量并将这些单独的张量传递给它的内部提供的函数。此功能不会给您带来任何问题。接下来是tf.cond. tf.cond期望它的谓词中有一个标量值。分解:

tensor = tf.constant([[0, 1, 0], [0.5, 0.5, 0.3]], tf.float32)
cond_val1 = tf.math.not_equal(tensor, 0.0)
print(cond_val1.shape)  # Shape (2, 3)

cond_val1在上面的例子中显然是一个张量。您将不得不使用tf.reduce_alltf.reduce_any将其转换为标量。然后,您将获得所需的标量tf.cond。例如:

cond_val2 = tf.reduce_all(tf.math.not_equal(tensor, 0.0))
print(cond_val2.shape)  # Shape ()

现在这将tf.cond起作用。但是你遇到了另一个问题。你已经失去了处理张量元素的能力。

其次,通过map_fn您在第一维传递整个张量拆分,在您的情况下将是[0, 1, 0]and [0.5, 0.5, 0.3]。但问题是你tf.condtrue_fn,并且false_fn没有能力按元素处理它。

希望您了解代码中存在的各种问题。


推荐阅读