python - 用 tensorflow.keras 连接两个模型
问题描述
我目前正在使用 MNIST 数据集研究用于图像分析的神经网络模型。我首先只使用图像来构建第一个模型。然后我创建了一个附加变量,即:当数字实际上在 0 和 4 之间时为 0,当它大于或等于 5 时为 1。
因此,我想建立一个可以获取这两个信息的模型:数字的图像,以及我刚刚创建的附加变量。
我创建了两个第一个模型,一个用于图像,一个用于外生变量,如下所示:
import tensorflow as tf
from tensorflow import keras
image_model = keras.models.Sequential()
#First conv layer :
image_model.add( keras.layers.Conv2D( 64, kernel_size=3,
activation=keras.activations.relu,
input_shape=(28, 28, 1) ) )
#Second conv layer :
image_model.add( keras.layers.Conv2D( 32, kernel_size=3, activation=keras.activations.relu ) )
#Flatten layer :
image_model.add( keras.layers.Flatten() )
print( image_model.summary(), '\n' )
info_model = keras.models.Sequential()
info_model.add( keras.layers.Dense( 5, activation=keras.activations.relu, input_shape=(1,) ) )
print( info_model.summary() )
然后我想连接两个最终层,最后用 softmax 放置另一个密集层来预测类概率。
我知道使用 Keras 函数 API 是可行的,但是如何使用 tf.keras 来做到这一点?
解决方案
您可以在 TF 中轻松使用 Keras 的功能 API(使用 TF 2.0 测试):
import tensorflow as tf
# Image
input_1 = tf.keras.layers.Input(shape=(28, 28, 1))
conv2d_1 = tf.keras.layers.Conv2D(64, kernel_size=3,
activation=tf.keras.activations.relu)(input_1)
# Second conv layer :
conv2d_2 = tf.keras.layers.Conv2D(32, kernel_size=3,
activation=tf.keras.activations.relu)(conv2d_1)
# Flatten layer :
flatten = tf.keras.layers.Flatten()(conv2d_2)
# The other input
input_2 = tf.keras.layers.Input(shape=(1,))
dense_2 = tf.keras.layers.Dense(5, activation=tf.keras.activations.relu)(input_2)
# Concatenate
concat = tf.keras.layers.Concatenate()([flatten, dense_2])
n_classes = 4
# output layer
output = tf.keras.layers.Dense(units=n_classes,
activation=tf.keras.activations.softmax)(concat)
full_model = tf.keras.Model(inputs=[input_1, input_2], outputs=[output])
print(full_model.summary())
这为您提供了您正在寻找的模型。
推荐阅读
- jquery - 我无法执行 ajax 请求并获得 json 响应
- vba - 如何使用正确的驱动程序为 w10 64 位 ms-access odbc 创建用户 dsn
- rust - 如何在不使用克隆的情况下迭代向量并比较元素以防止借用错误
- javascript - 动态创建对象数组错误修复
- react-native - 在 DrawerNavigator 中通过屏幕选项传递和获取道具
- php - 一个复选框 两个具有特定用途的值
- ejb - JSR 345 的文档来源(EJB 3.2 Final)
- python - Python和egrep中正则表达式括号的区别
- javascript - 如何让我的 Discord 机器人加入特定服务器?该机器人需要付费订阅,我不希望邀请链接被泄露
- python - Django CSRF 导致服务器在错误的 URL 请求后出错