python - 广播数组上的 tf.where()
问题描述
我有两个数组(x 是 1D,y 是 2D)。我计算了数组“diff”,它基本上是广播差异(xy[:,None])。我想用一个很大的值(比如 10000)替换数组“diff”中的所有零。这个操作在 numpy 中是微不足道的,如下所示:
x=np.array([1.0,1.0,1.0])
y=np.array([[1.0,1.0,1.0],[0.0,0.0,0.0]])
diff = x - y[:, None]
diff = np.where(diff==0.0,10000,diff)
但是,我无法在 Tensorflow 中重现相同的行为。我尝试了以下代码块。
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
diff = x - y[:,None]
diff_zero = tf.cast(tf.zeros_like(diff),tf.float32)
diff_big = tf.cast(tf.ones_like(diff)*100000,tf.float32)
diff = tf.where(diff==diff_zero, diff_big, diff)
sess = tf.Session()
diff_array = sess.run(diff, feed_dict={x: [1.0,1.0,1.0], y: [[1.0,1.0,1.0],[0.0,0.0,0.0]]})
任何解决方法将不胜感激。
解决方案
我想出了怎么做。我不得不使用 tf.equal() 而不是“==”。以下代码行就像 numpy 一样完成了这项工作。
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
diff = x - y[:,None]
diff_zero = tf.cast(tf.zeros_like(diff),tf.float32)
diff_big = tf.cast(tf.ones_like(diff)*100000,tf.float32)
condition = tf.equal(diff_zero, diff)
diff = tf.where(condition, diff_big, diff)
sess = tf.Session()
diff_array = sess.run(diff, feed_dict={x: [1.0,1.0,1.0], y: [[1.0,1.0,1.0],[0.0,0.0,0.0]]})
推荐阅读
- c++ - 为什么`std::this_thread::sleep_for` 函数拒绝在 WSL Ubuntu 上休眠?
- docker - 在docker中,“减少build-index组件的缓存间隔”是什么意思
- python - 如何使用 sys.stdin 在 python 中只读取第一个空行
- java - java中表达式的别名
- sql - 如何比较一个数据库表的列与另一列的相似性
- c# - 我需要将此 powershell 代码转换为 C#,因为我无法设置执行策略来运行 powershell 脚本
- c# - 如何在 C# 中将字符串与字符串数组进行比较?
- nestjs - 如何在 NestJS 拦截器中获取存储库
- windows - 如何在 Windows 10 上安装 patchutils
- javascript - 如何使用 SQLite3 和 node.js 从异步函数返回值?