python - Tensorflow tf.switch_case 不适用于 keras 输入
问题描述
我想使用 tf.switch_case 能够使用输入重定向网络不同分支的学习流,但是 tf.switch_case 不适用于 Keras.Tensor ...
import tensorflow as tf
from tensorflow.keras.layers import Input
def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)
t_input = Input(shape=(1,), name="t_input")
r = tf.switch_case(t_input, branch_fns={0: f1, 1: f2}, default=f3)
Traceback (most recent call last):
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-9-bd285228541c>", line 5, in <module>
r = tf.switch_case(t_input, branch_fns={0: f1, 1: f2}, default=f3)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3616, in switch_case
return _indexed_case_helper(branch_fns, default, branch_index, name)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3315, in _indexed_case_helper
branch_fns, default, branch_index)
File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3249, in _indexed_case_verify_and_canonicalize_args
type(branch_index)))
TypeError: branch_index must a Tensor, got <class 'tensorflow.python.keras.engine.keras_tensor.KerasTensor'>
解决方案
我不确定你为什么要在 Keras 模型中这样做。但是在使用这些时要小心。f1
f2
例如,如果您在其中定义模型,我不确定这些(对于梯度流动)的测试效果如何。尽管如此,您可以执行以下操作。
- 将您的输入定义为
int32
类型。因为tf.switch_case
期待int32
。 - 确保您定义
batch_shape
而不是shape
,以便您可以索引输入张量以获取要传递给的标量值tf.switch_case
。 - 包
tf.switch_case
在Lambda
一层。
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)
t_input = Input(batch_shape=(1,), dtype='int32', name="t_input")
r = Lambda(lambda x: tf.switch_case(x[0], branch_fns={0: f1, 1: f2}, default=f3))(t_input)
model = tf.keras.models.Model(inputs=t_input, outputs=r)
使用它
print(model.predict([1]))
退货
31
推荐阅读
- sql - SSIS 在我的包的每次执行中创建一个 csv 文件
- linux - 如何在 ssh_config 的 Azure CLI bash 窗口中将 StrictHostKeyChecking 设置为 no
- java - Java:如何检查哈希图中是否存在键
- revolution-slider - 移动设备上的向下滚动问题(Revolution Slider 6)
- deployment - 如何导出电子应用程序以在浏览器中使用
- python - 使用 Jupyter 时终端出现错误消息
- react-native - ScrollView 子项相对于容器的高度
- html - 如何调试为什么 HTML 元素具有某些属性
- sql - SQLITE 约束显示为新列
- algorithm - 为什么 hIndex 是 3 而不是 0,因为有 5 篇论文的引用次数为 0?