首页 > 解决方案 > 如何创建具有子模型条件评估的组合 tf.keras 模型

问题描述

我想创建多个tf.keras.Sequential模型的组合,以便在任何给定时间点只评估一个子模型。为了更好地解释,我创建了以下模型(模型的代码在本文末尾):

组合模型图

在该图中,模型、sequential和是基于 LSTM 的子模型,是简单的第五个子模型。最后一层根据输入数据中的值(从 中提取)决定五个并行路径中的哪一个将实际提供网络输出。其他四个值被丢弃。sequential_1sequential_2sequential_3label_0arbiterin_arb

当然,其他四个并行层(对结果没有贡献)的计算被浪费了。所以我的问题是:有没有办法在 TensorFlow 内部解决这个问题,例如使用某种条件图路由而不是并行执行?

模型示例代码:

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)

batch_size = 1

def gen_lstm(base_label, num_features, num_units):
    return tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(1, num_features), batch_size=batch_size,
                                   name="input_{}".format(base_label)),
        tf.keras.layers.LSTM(num_units,
                             batch_input_shape=(batch_size, 1, num_features),
                             return_sequences=False, stateful=True,
                             name="lstm_{}".format(base_label)),
        tf.keras.layers.Dense(1, name="dense_{}".format(base_label)), # binary
        tf.keras.layers.Activation('sigmoid', name="activ_{}".format(base_label)), # binary
    ])

models = {}
for l in [1, 3, 4, 5]:
    global m
    m = gen_lstm(l, 130, 88)
    models[l] = m

in_all = tf.keras.layers.InputLayer(input_shape=(1, 132), batch_size=batch_size, name="input_all")
in_lstm = tf.keras.layers.Lambda(lambda x: tf.slice(x, [0, 0, 2], [-1, -1, -1]), name="in_lstm")(in_all.output)

def model_0(x):
    return x[:, :, 1]

out_model_0 = tf.keras.layers.Lambda(model_0, name="label_0")(in_all.output)

out_concat = tf.keras.layers.Concatenate(axis=1, name="concat_infer")([out_model_0] + [m(in_lstm) for m in models.values()])

in_arb = tf.keras.layers.Lambda(lambda x: tf.reshape(tf.slice(x, [0, 0, 0], [-1, -1, 1]), (batch_size, 1)), name="in_arb")(in_all.output)
out_merged = tf.keras.layers.Concatenate(axis=1, name="concat_arb")([in_arb, out_concat])

def arbiter(x):
    return tf.where(tf.equal(x[:, 0], tf.constant(0.0, dtype=tf.float32)), x[:, 1], tf.where(
        tf.equal(x[:, 0], tf.constant(2.0, dtype=tf.float32)), x[:, 2] + tf.constant(2.0), tf.where(
            tf.equal(x[:, 0], tf.constant(3.0, dtype=tf.float32)), x[:, 3] + tf.constant(3.0), tf.where(
                tf.equal(x[:, 0], tf.constant(4.0, dtype=tf.float32)), x[:, 4] + tf.constant(4.0),
                x[:, 5] + tf.constant(5.0)))))

out_merged = tf.keras.layers.Lambda(arbiter, name="arbiter")(out_merged)

lstm_model = tf.keras.Model([in_all.input], out_merged)

print(lstm_model.summary())
tf.keras.utils.plot_model(lstm_model, to_file="./temp.png", show_shapes=True)

标签: tensorflowkerastf.keras

解决方案


推荐阅读