tensorflow - 用 if-else 组合两个网络
问题描述
我有两个网络ModelUp
,ModelDown
它们采用相同的输入x1, x2
,其中x1
10 个特征x2
是每个样本的单个数字。我想要一个combinedModel
网络:
if x2>=1:
return ModelUp([x1,x2])
else:
return ModelDown([x1,x2])
应该不需要培训combinedModel
一次ModelUp
,并且ModelDown
已经单独接受过培训。
我怎样才能在tensorflow.keras
(tensorflow版本是1.12.0
)中进行这种组合?
解决方案
以下代码将是一个选项。您需要检查它是否适用于您的用例。
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
def combineModels(ModelUp, ModelDown):
input1 = Input(shape=(10,))
input2 = Input(shape=(1,))
#
selectModel1 = K.cast(K.greater_equal(input2, 1), dtype='float32')
selectModel2 = K.cast(K.less(input2, 1), dtype='float32')
#
out = ModelUp([input1,input2]) * selectModel1 + ModelDown([input1,input2]) * selectModel2
model = Model(inputs=[input1,input2], outputs=out)
return model
combinedModel = combineModels(ModelUp, ModelDown)
对于张量流 1:
from tensorflow.keras.layers import Input, Lambda, Multiply, Add
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
def combineModels(ModelUp, ModelDown):
input1 = Input(shape=(10,))
input2 = Input(shape=(1,))
#
selectModel1 = Lambda(lambda x: K.greater_equal(x, K.constant(1.))) (input2)
selectModel2 = Lambda(lambda x: K.less(x, K.constant(1.)))(input2)
#
selectModel1 = Lambda(lambda x: K.cast(x, dtype='float32'))(selectModel1)
selectModel2 = Lambda(lambda x: K.cast(x, dtype='float32'))(selectModel2)
#
out1 = Multiply()([ModelUp([input1,input2]), selectModel1])
out2 = Multiply()([ModelDown([input1,input2]), selectModel2])
out = Add()([out1, out2])
model = Model(inputs=[input1,input2], outputs=out)
return model
combinedModel = combineModels(ModelUp, ModelDown)
推荐阅读
- javascript - 如何在 REACT js 中正确地从 JSON 对象中删除特定项目?
- delphi - Rad Studio 10.2.3:找不到驱动程序/连接注册表文件 dbxconnections.ini
- python - 将 HTML 中的段落文本格式化为单行
- c++ - Ifstream不使用C++中的数组检索数据
- php - 如何从另一个表中选择表中的特定条件
- amazon-web-services - EC2 Cloudformation 模板的 Cloud-init 中的用户数据脚本错误
- javascript - 如何以对象的形式向服务器发送多个值
- amazon-web-services - 如何使用MWS服务为多个亚马逊卖家开发应用程序?
- r - 制作直方图
- python - 如何在二进制文件中获得 readelf/IDA 和 Aho-Corasick 之间的相同偏移量