首页 > 解决方案 > 如何在 TensorFlow 图中编写 if 语句?

问题描述

在我要构建的模型中,有两个占位符

x = tf.placeholder('float32', shape=[1000, 10])
tags = tf.placeholder('int32', shape=[1000, 1])

(1000 只是示例的数量)

x保存神经网络的输入,tags确定三个神经网络中的哪一个将用于计算输出。

w1 = tf.get_variable('w1', [10, 1], tf.truncated_normal_initializer())
w2 = tf.get_variable('w2', [10, 1], tf.truncated_normal_initializer())
w3 = tf.get_variable('w3', [10, 1], tf.truncated_normal_initializer())

def nn_1(): return tf.matmul(x, w1)
def nn_2(): return tf.matmul(x, w2)
def nn_3(): return tf.matmul(x, w3)

我想找到一种优雅的方式来实现一个 TensorFlow 图,它​​可以计算x给定其tag.

[x1, x2, x3, ..., xn]
[1,  2,  3,  ..., 1]
[nn_1(x), nn_2(x), nn_3(x), ..., nn_1(x)]

如果xandtags不是数组,我可以用 来实现它tf.case,例如,

a = tf.placeholder('int32')
b = tf.placeholder('int32')
result = tf.case(
    {
        tf.equal(b, 1): a + 1, 
        tf.equal(b, 2): a + 2
    })

但我不知道何时xtags是数组。

标签: pythontensorflow

解决方案


您可以使用一些数学技巧来完成这项工作。

假设您要实现您实现的代码,但 a 和 b 是数组。

首先,您计算一个条件数组。这将是应用该操作必须为真的条件。通常条件使用“更少”、“相等”、“更大”操作或这些操作的逻辑组合。您可以使用tf.bitwiseortf.math.logical*进行逻辑操作和tf.math其他操作。条件必须是一个布尔数组。如果条件为真,则为 1,如果为假,则为 0。

之后,使用默认值初始化结果数组(“else”语句中的内容)

要应用条件,您只需将条件数组与要分配的值相乘。

代码将是这样的。

//default value
result = tf.zeros(tf.shape(a)[0])
condition = tf.equal(b, index)
condition = tf.cast(condition, tf.float32)
result = tf.multiply(condition, a) + index

如果要使用tag函数index数组,则需要使用二维数组。创建一个包含所有可能组合 nn X x 的二维数组。该数组将包含每个 i,j 对的 nn_j(x[i])。

为此,您需要创建一个数组 x X nn X 2 数组。首先,展开 x 并创建一个包含 x X nn 数组的数组

如果你的 x 是 x=[0,2,1] 并且 len(n) = 2 那么你需要 x_nn = [[0,0], [2,2], [1,1]]。

nn_x = x
nn_x = tf.expand_dims(nn_x,0)
nn_x = tf.tile(nn_x, [len(nn), 1])

然后,您创建一个具有相同形状且索引为 nn 的二维数组。对于 arrey 使用的早期 index2d = [[0,1],[0,1],[0,1]]

index = tf.linspace(0,len(nn)-1)
index2d = tf.expand_dims(index,0)
index2d = tf.tile(index2d, tf.shape(x)[0])

然后你需要堆叠这些和数组,将第一个维度移动到最后一个位置,然后沿轴 0 和 1 平展。

这样,您将拥有 map2d = [[0,0],[0,1],[2,0],[2,1],[1,0],[1,1]] 对于每对第一对是 x 的值,第二个是 nn 的索引

然后使用该tf.map_fn函数映射这个二维数组。写类似的东西

tf.map_fn(t => [nn[t[1]](t[0]), t[1]], map2d)

现在每个 x 都有 nn 的所有可能值

此时,您可以重新调整 map2d,将 map2d[:,:,1] 与您的标签进行比较,然后选择相等的那个。

#reshape map2d
# ...
# transform tag
tag2d = tag
tag2d = tf.expand_dims(tag2d,0)
tag2d = tf.tile(tag2d, [len(nn), 1])
result = tf.equal(tag, map2d[:,:,1])

结果每列只有一个非零值

结果 = tf.multiply(result, map2d[:,:,0]) 结果 = tf.reduce_max(result, [1])

我没有尝试代码,但该机制应该可以工作。

希望这有帮助


推荐阅读