tensorflow - 如何创建具有子模型条件评估的组合 tf.keras 模型
问题描述
我想创建多个tf.keras.Sequential
模型的组合,以便在任何给定时间点只评估一个子模型。为了更好地解释,我创建了以下模型(模型的代码在本文末尾):
在该图中,模型、sequential
和是基于 LSTM 的子模型,是简单的第五个子模型。最后一层根据输入数据中的值(从 中提取)决定五个并行路径中的哪一个将实际提供网络输出。其他四个值被丢弃。sequential_1
sequential_2
sequential_3
label_0
arbiter
in_arb
当然,其他四个并行层(对结果没有贡献)的计算被浪费了。所以我的问题是:有没有办法在 TensorFlow 内部解决这个问题,例如使用某种条件图路由而不是并行执行?
模型示例代码:
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)
batch_size = 1
def gen_lstm(base_label, num_features, num_units):
return tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(1, num_features), batch_size=batch_size,
name="input_{}".format(base_label)),
tf.keras.layers.LSTM(num_units,
batch_input_shape=(batch_size, 1, num_features),
return_sequences=False, stateful=True,
name="lstm_{}".format(base_label)),
tf.keras.layers.Dense(1, name="dense_{}".format(base_label)), # binary
tf.keras.layers.Activation('sigmoid', name="activ_{}".format(base_label)), # binary
])
models = {}
for l in [1, 3, 4, 5]:
global m
m = gen_lstm(l, 130, 88)
models[l] = m
in_all = tf.keras.layers.InputLayer(input_shape=(1, 132), batch_size=batch_size, name="input_all")
in_lstm = tf.keras.layers.Lambda(lambda x: tf.slice(x, [0, 0, 2], [-1, -1, -1]), name="in_lstm")(in_all.output)
def model_0(x):
return x[:, :, 1]
out_model_0 = tf.keras.layers.Lambda(model_0, name="label_0")(in_all.output)
out_concat = tf.keras.layers.Concatenate(axis=1, name="concat_infer")([out_model_0] + [m(in_lstm) for m in models.values()])
in_arb = tf.keras.layers.Lambda(lambda x: tf.reshape(tf.slice(x, [0, 0, 0], [-1, -1, 1]), (batch_size, 1)), name="in_arb")(in_all.output)
out_merged = tf.keras.layers.Concatenate(axis=1, name="concat_arb")([in_arb, out_concat])
def arbiter(x):
return tf.where(tf.equal(x[:, 0], tf.constant(0.0, dtype=tf.float32)), x[:, 1], tf.where(
tf.equal(x[:, 0], tf.constant(2.0, dtype=tf.float32)), x[:, 2] + tf.constant(2.0), tf.where(
tf.equal(x[:, 0], tf.constant(3.0, dtype=tf.float32)), x[:, 3] + tf.constant(3.0), tf.where(
tf.equal(x[:, 0], tf.constant(4.0, dtype=tf.float32)), x[:, 4] + tf.constant(4.0),
x[:, 5] + tf.constant(5.0)))))
out_merged = tf.keras.layers.Lambda(arbiter, name="arbiter")(out_merged)
lstm_model = tf.keras.Model([in_all.input], out_merged)
print(lstm_model.summary())
tf.keras.utils.plot_model(lstm_model, to_file="./temp.png", show_shapes=True)
解决方案
推荐阅读
- node.js - 错误缺少绑定 app/nodes_modules/node-sass
- mysql - GROUP BY MySql 的替代方案
- python - 如何在 seaborn catplot 上插入一条恒定的水平线
- c - 我如何描述C中二维数组的每一列?
- python - RandomForestClassifier:GridSearchCV 后召回率低
- python - 如何在 numpy 结构化数组中选择一行并设置一个值?
- c - 如何正确打印/排序使用 fread 读取到数组中的二进制文件数据
- computer-vision - 立体视觉 3D 重建:关于图像校正的说明
- python - 使用 matplotlib 或 Seaborn 在 Python 中使用日期时间索引绘制组条形图
- javascript - 如何在显示消息期间添加按钮加载:Vue.js