python - 根据 Tensorflow 中另一个向量中的元素计算向量中的值
问题描述
我有两个向量:时间和事件。如果一个事件为 1,则应将同一索引处的时间分配给func_for_event1
。否则,它转到func_for_event0
。
import tensorflow as tf
def func_for_event1(t):
return t + 1
def func_for_event0(t):
return t - 1
time = tf.placeholder(tf.float32, shape=[None]) # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None]) # [0, 1, 1, 0, 1]
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.
我应该如何在 Tensorflow 中实现这个逻辑?说tf.cond
或tf.where
?
解决方案
这正是tf.where()
它的用途。此代码(已测试):
import tensorflow as tf
import numpy as np
def func_for_event1(t):
return t + 1
def func_for_event0(t):
return t - 1
time = tf.placeholder(tf.float32, shape=[None]) # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None]) # [0, 1, 1, 0, 1]
result = tf.where( tf.equal( 1, event ), func_for_event1( time ), func_for_event0( time ) )
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.
with tf.Session() as sess:
res = sess.run( result, feed_dict = {
time : np.array( [3.2, 4.2, 1.0, 1.05, 1.8] ),
event : np.array( [0, 1, 1, 0, 1] )
} )
print ( res )
输出:
[2.2 5.2 2. 0.04999995 2.8]
如预期的。
推荐阅读
- bit-manipulation - 二进制数中最长连续 1 的长度
- java - git:将现有存储库拆分为子模块
- r - 将 .csv 文件加载到 RStudio 时出现问题。带引号的字符串中的 EOF
- perl - 在 Perl 中将文件的行读入并行哈希
- node.js - 为什么这个 Cloud Function 不会触发通知?
- go - 以字符串形式发送带有请求正文的 POST 请求
- mysql - 在mysql数据库中插入HTML作为文本
- javascript - 使用 JSON 数据为类别 x 轴设置 C3 子图范围?
- php - 自定义类未加载到控制器中
- jquery - 如何通过在请求中传递宽度参数来调整图像大小?