python - 在 keras 中跳过预训练模型的连接
问题描述
所以我目前正在以下论文中实现模型https://openaccess.thecvf.com/content_cvpr_2018/papers/Oh_Fast_Video_Object_CVPR_2018_paper.pdf
正如下面的模型所示,他们在模型中使用了 2 个 resnet50 模型的图像 标记为连体编码器
我使用了 Keras 提供的 resnet50 模型,代码如下:
input_shape = (480,854,4)
inputlayer_Q = Input(shape=input_shape, name="inputlayer_Q")
convlayer_Q = Conv2D(filters= 3,kernel_size = (3,3),padding = 'same')(inputlayer_Q)
model_Q = tf.keras.applications.resnet50.ResNet50(
input_shape=(
convlayer_Q.shape[1],convlayer_Q.shape[2],convlayer_Q.shape[3]),
include_top=False,
weights='imagenet'
)
然后他们从 resnet 模型内部的层中获取了 3 个跳过连接,我尝试使用以下行来获取跳过连接
res2_skip = model_Q.layers[38].output
res3_skip = model_Q.layers[80].output
res4_skip = model_Q.layers[142].output
但是当我稍后在模型中使用它并尝试运行它时,会给我 Graph 断开连接。
那么有什么方法可以跳过 Keras 提供的连接/修改模型?
解决方案
尝试这个:
input_shape = (480,854,4)
# Target Stream = Q
inputlayer_Q = Input(shape=input_shape, name="inputlayer_Q")
# Refrence Stream = M
inputlayer_M = Input(shape=input_shape,name="inputlayer_M")
convlayer_Q = Conv2D(filters= 3,kernel_size = (3,3),padding = 'same')(inputlayer_Q)
convlayer_M = Conv2D(filters= 3,kernel_size = (3,3),padding = 'same')(inputlayer_M)
model_Q = tf.keras.applications.resnet50.ResNet50(
input_shape=(convlayer_Q.shape[1],convlayer_Q.shape[2],convlayer_Q.shape[3]), include_top=False, weights='imagenet'
)
model_Q._name ="resnet50_Q"
model_M = tf.keras.applications.resnet50.ResNet50(
input_shape=(convlayer_M.shape[1],convlayer_M.shape[2],convlayer_M.shape[3]), include_top=False, weights='imagenet'
)
model_M._name ="resnet50_M"
for model in [model_Q, model_M]:
for layer in model.layers:
old_name = layer.name
layer._name = f"{model.name}_{old_name}"
print(layer._name)
encoder_Q = tf.keras.Model(inputs=model_Q.inputs, outputs=model_Q.output,name ="encoder_Q" )
encoder_M = tf.keras.Model(inputs=model_M.inputs, outputs=model_M.output,name ="encoder_M" )
concatenate = Concatenate(axis=0,name ="Concatenate")([encoder_Q.output, encoder_M.output])
global_layer = GlobalConvBlock(concatenate)
res2_skip = encoder_Q.layers[38].output
res2_skip = ZeroPadding2D(padding=(0,1), data_format=None)(res2_skip)
res3_skip = encoder_Q.layers[80].output
res3_skip = ZeroPadding2D(padding=((0,0),(0,1)), data_format=None)(res3_skip)
res4_skip = encoder_Q.layers[142].output
ref1_16 = refineblock(res4_skip,global_layer,"ref1_16")
ref1_8 = refineblock(res3_skip,ref1_16,"ref1_8")
ref1_4 = refineblock(res2_skip,ref1_8,"ref1_4")
outconv = Conv2D(filters= 2,kernel_size = (3,3)) (ref1_4)
outconv1 = ZeroPadding2D(padding=((1,1),(0,0)), data_format=None)(outconv)
output = Softmax()(outconv1)
main_model = tf.keras.Model(inputs=[encoder_Q.inputs, encoder_M.inputs],outputs=output, name ="main model" )
推荐阅读
- assembly - MIPS 初学者,为什么它不打印我的第三个提示?
- java - jconsole可以用来监控实时应用吗-prode中部署的应用
- php - 为什么 XDebug 会导致页面在 Visual Studio Code PHP 调试扩展中持续加载?
- postgresql - Flask 应用程序 AWS Postgres 连接在本地工作,但不在 Heroku 上
- react-native-popup-menu - 如何在自定义平面列表中实现这个库
- cytoscape.js - 图中的弹性(动画)边,节点对邻居施加引力?
- java - RecyclerView 最初不加载数据
- javascript - 循环遍历对象数组并创建一个真值数组
- typescript - 如何键入此合并功能?
- jquery - 从 div 容器中获取选定的数据以存储在 asp.net mvc 中的数据库中