首页 > 解决方案 > 广播数组上的 tf.where()

问题描述

我有两个数组(x 是 1D,y 是 2D)。我计算了数组“diff”,它基本上是广播差异(xy[:,None])。我想用一个很大的值(比如 10000)替换数组“diff”中的所有零。这个操作在 numpy 中是微不足道的,如下所示:

x=np.array([1.0,1.0,1.0])
y=np.array([[1.0,1.0,1.0],[0.0,0.0,0.0]])
diff = x - y[:, None]
diff = np.where(diff==0.0,10000,diff)

但是,我无法在 Tensorflow 中重现相同的行为。我尝试了以下代码块。

x = tf.placeholder(tf.float32) 
y = tf.placeholder(tf.float32)
diff = x - y[:,None]
diff_zero = tf.cast(tf.zeros_like(diff),tf.float32)
diff_big = tf.cast(tf.ones_like(diff)*100000,tf.float32)

diff = tf.where(diff==diff_zero, diff_big, diff)

sess = tf.Session()
diff_array = sess.run(diff, feed_dict={x: [1.0,1.0,1.0], y: [[1.0,1.0,1.0],[0.0,0.0,0.0]]})

任何解决方法将不胜感激。

标签: pythonnumpytensorflow

解决方案


我想出了怎么做。我不得不使用 tf.equal() 而不是“==”。以下代码行就像 numpy 一样完成了这项工作。

x = tf.placeholder(tf.float32) 
y = tf.placeholder(tf.float32)
diff = x - y[:,None]

diff_zero = tf.cast(tf.zeros_like(diff),tf.float32)
diff_big = tf.cast(tf.ones_like(diff)*100000,tf.float32)

condition = tf.equal(diff_zero, diff)
diff = tf.where(condition, diff_big, diff)
sess = tf.Session()

diff_array = sess.run(diff, feed_dict={x: [1.0,1.0,1.0], y: [[1.0,1.0,1.0],[0.0,0.0,0.0]]})

推荐阅读