首页 > 解决方案 > tf.where() 在操纵张量时的行为不符合预期

问题描述

我尝试了以下代码:

a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))

不会产生与以下相同的结果:

a = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)

这里 x 是 0 或 1 的二进制变量。b 是介于 0 和 1 之间的实数值。

有什么我想念的吗?
我比较两个答案的方式是 tf.reduce_sum(a)

找到的解决方案:2 确实等效于 x = 0 或 x = 1。我使用的数据是 2D 张量,其中有一些不是 0 或 1 的位。这是通过发现的 tf.unique(tf.reshape(x, (-1,))

标签: tensorflowmultiplexing

解决方案


代码示例:

# When x = 0.0 

x = 0.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # 0 + (1-x)*tf.math.log(1 - b + 1e-19) = -0.6931472

# When x = 1.0 

x = 1.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from x*tf.math.log(b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + 0 = -0.6931472


# When x = 0.4 

x = 0.4
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.41588834 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) +  (1-x)*tf.math.log(1 - b + 1e-19) = -0.27725 + -0.41588 = -0.6931471824645996.

您问题中提到的两个代码会产生相同结果的唯一情况是 x = 1 或 0。

tf.reduce_sum(a)

这里,a 是一个标量,所以这不会改变 a 的值。


推荐阅读