python - 如何将两个 keras 模型连接到一个模型中?
问题描述
假设我有一个 ResNet50 模型,我希望将该模型的输出层连接到 VGG 模型的输入层。
这是 ResNet 模型和 ResNet50 的输出张量:
img_shape = (164, 164, 3)
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights = None)
print(resnet50_model.output.shape)
我得到输出:
TensorShape([Dimension(None), Dimension(6), Dimension(6), Dimension(2048)])
现在我想要一个新层,我将这个输出张量重塑为 (64,64,18)
然后我有一个 VGG16 模型:
VGG_model = VGG_model = VGG16(include_top=False, weights=None)
我希望 ResNet50 的输出重塑为所需的张量,并作为 VGG 模型的输入。所以本质上我想连接两个模型。有人可以帮我这样做吗?谢谢!
解决方案
有多种方法可以做到这一点。这是使用 Sequential 模型 API 执行此操作的一种方法。
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16
model = tf.keras.Sequential()
img_shape = (164, 164, 3)
model.add(ResNet50(include_top=False, input_shape=img_shape, weights = None))
model.add(tf.keras.layers.Reshape(target_shape=(64,64,18)))
model.add(tf.keras.layers.Conv2D(3,kernel_size=(3,3),name='Conv2d'))
VGG_model = VGG16(include_top=False, weights=None)
model.add(VGG_model)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
模型总结如下
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
resnet50 (Model) (None, 6, 6, 2048) 23587712
_________________________________________________________________
reshape (Reshape) (None, 64, 64, 18) 0
_________________________________________________________________
Conv2d (Conv2D) (None, 62, 62, 3) 489
_________________________________________________________________
vgg16 (Model) multiple 14714688
=================================================================
Total params: 38,302,889
Trainable params: 38,249,769
Non-trainable params: 53,120
_________________________________________________________________
完整代码在这里。
推荐阅读
- angular - 首次加载 Angular 时重定向到授权端点
- python - Tensorflow 批量大小影响精度
- sql - SQL - 计算所有列中出现的次数
- node.js - 我的节点 js 程序没有得到我的端点,因此不能 POST 或 GET
- firebase - 如何将 Cloud Firestore 数据库集合下载到 JSON 或 CSV 文件中?
- python-3.x - 为什么调用后参数值会存储在我的函数中?
- javascript - json_encode() 向我的 REST API 添加斜杠
- json - 如何从 Jenkinsfile 读取 JSON 元素以进行多分支管道
- database - 向 DolphinDB 表中的列添加注释
- android - 为什么我的代码在 Firestore 中写入了 doc ID 的完整路径?