首页 > 解决方案 > 用 if-else 组合两个网络

问题描述

我有两个网络ModelUpModelDown它们采用相同的输入x1, x2,其中x110 个特征x2是每个样本的单个数字。我想要一个combinedModel网络:

if x2>=1:
    return ModelUp([x1,x2])
else:
    return ModelDown([x1,x2])

应该不需要培训combinedModel一次ModelUp,并且ModelDown已经单独接受过培训。

我怎样才能在tensorflow.keras(tensorflow版本是1.12.0)中进行这种组合?

标签: tensorflowkeras

解决方案


以下代码将是一个选项。您需要检查它是否适用于您的用例。

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

def combineModels(ModelUp, ModelDown):
    input1 = Input(shape=(10,))
    input2 = Input(shape=(1,))
    # 
    selectModel1 = K.cast(K.greater_equal(input2, 1), dtype='float32')
    selectModel2 = K.cast(K.less(input2, 1), dtype='float32')
    # 
    out   = ModelUp([input1,input2]) * selectModel1 + ModelDown([input1,input2]) * selectModel2
    model = Model(inputs=[input1,input2], outputs=out)
    return model

combinedModel = combineModels(ModelUp, ModelDown)

对于张量流 1:

from tensorflow.keras.layers import Input, Lambda, Multiply, Add
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

def combineModels(ModelUp, ModelDown):
    input1       = Input(shape=(10,))
    input2       = Input(shape=(1,))
    # 
    selectModel1 = Lambda(lambda x: K.greater_equal(x, K.constant(1.)))    (input2)
    selectModel2 = Lambda(lambda x: K.less(x, K.constant(1.)))(input2)
    # 
    selectModel1 = Lambda(lambda x: K.cast(x, dtype='float32'))(selectModel1)
    selectModel2 = Lambda(lambda x: K.cast(x, dtype='float32'))(selectModel2)
    # 
    out1 = Multiply()([ModelUp([input1,input2]), selectModel1])
    out2 = Multiply()([ModelDown([input1,input2]), selectModel2])
    out  = Add()([out1, out2])
    model = Model(inputs=[input1,input2], outputs=out)
    return model

combinedModel = combineModels(ModelUp, ModelDown)

推荐阅读