首页 > 解决方案 > Tensorflow tf.switch_case 不适用于 keras 输入

问题描述

我想使用 tf.switch_case 能够使用输入重定向网络不同分支的学习流,但是 tf.switch_case 不适用于 Keras.Tensor ...

import tensorflow as tf
from tensorflow.keras.layers import Input
def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)
t_input = Input(shape=(1,), name="t_input")
r = tf.switch_case(t_input, branch_fns={0: f1, 1: f2}, default=f3)
Traceback (most recent call last):
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-bd285228541c>", line 5, in <module>
    r = tf.switch_case(t_input, branch_fns={0: f1, 1: f2}, default=f3)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3616, in switch_case
    return _indexed_case_helper(branch_fns, default, branch_index, name)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3315, in _indexed_case_helper
    branch_fns, default, branch_index)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3249, in _indexed_case_verify_and_canonicalize_args
    type(branch_index)))
TypeError: branch_index must a Tensor, got <class 'tensorflow.python.keras.engine.keras_tensor.KerasTensor'>

标签: pythontensorflowmachine-learningkeras

解决方案


我不确定你为什么要在 Keras 模型中这样做。但是在使用这些时要小心。f1 f2例如,如果您在其中定义模型,我不确定这些(对于梯度流动)的测试效果如何。尽管如此,您可以执行以下操作。

  1. 将您的输入定义为int32类型。因为tf.switch_case期待int32
  2. 确保您定义batch_shape而不是shape,以便您可以索引输入张量以获取要传递给的标量值tf.switch_case
  3. tf.switch_caseLambda一层。
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda

def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)

t_input = Input(batch_shape=(1,), dtype='int32', name="t_input")

r = Lambda(lambda x: tf.switch_case(x[0], branch_fns={0: f1, 1: f2}, default=f3))(t_input)
model = tf.keras.models.Model(inputs=t_input, outputs=r)

使用它

print(model.predict([1]))

退货

31

推荐阅读