首页 > 解决方案 > 根据 Tensorflow 中另一个向量中的元素计算向量中的值

问题描述

我有两个向量:时间和事件。如果一个事件为 1,则应将同一索引处的时间分配给func_for_event1。否则,它转到func_for_event0

import tensorflow as tf

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None])  # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None])  # [0, 1, 1, 0, 1]

# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.

我应该如何在 Tensorflow 中实现这个逻辑?说tf.condtf.where

标签: pythontensorflow

解决方案


这正是tf.where()它的用途。此代码(已测试):

import tensorflow as tf
import numpy as np

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None])  # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None])  # [0, 1, 1, 0, 1]

result = tf.where( tf.equal( 1, event ), func_for_event1( time ), func_for_event0( time ) )
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.

with tf.Session() as sess:
    res = sess.run( result, feed_dict = {
        time : np.array( [3.2, 4.2, 1.0, 1.05, 1.8] ),
        event : np.array( [0, 1, 1, 0, 1] )
    } )
    print ( res )

输出:

[2.2 5.2 2. 0.04999995 2.8]

如预期的。


推荐阅读