
Problem description

I am currently implementing the model from the following paper: https://openaccess.thecvf.com/content_cvpr_2018/papers/Oh_Fast_Video_Object_CVPR_2018_paper.pdf

As shown in the model diagram below, they use two ResNet50 models as image encoders, labeled the Siamese encoders.

I used the ResNet50 model provided by Keras, as follows:

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D

input_shape = (480, 854, 4)
inputlayer_Q = Input(shape=input_shape, name="inputlayer_Q")
# 3x3 conv maps the 4-channel input down to the 3 channels ResNet50 expects
convlayer_Q = Conv2D(filters=3, kernel_size=(3, 3), padding='same')(inputlayer_Q)
model_Q = tf.keras.applications.resnet50.ResNet50(
    input_shape=(convlayer_Q.shape[1], convlayer_Q.shape[2], convlayer_Q.shape[3]),
    include_top=False,
    weights='imagenet'
)

Then they take three skip connections from layers inside the ResNet model. I tried to get those skip connections with the following lines:

res2_skip = model_Q.layers[38].output
res3_skip = model_Q.layers[80].output
res4_skip = model_Q.layers[142].output

But when I use these outputs later in the model and try to run it, I get a "Graph disconnected" error.

So is there any way to take skip connections from / modify the model provided by Keras?
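(For context: "Graph disconnected" happens because the stock ResNet50 has its own `Input` layer, so its intermediate tensors belong to a graph that is not connected to your `inputlayer_Q`. The usual fix is to wrap the intermediate activations into a multi-output feature extractor and then *call* that extractor on your own tensor. A minimal sketch, using a small 224x224 input and `weights=None` just to keep it light, and the standard TF 2.x ResNet50 block-output layer names rather than integer indices, which vary across Keras versions:)

```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D

# Stock backbone; weights=None avoids the ImageNet download in this sketch.
backbone = tf.keras.applications.resnet50.ResNet50(
    input_shape=(224, 224, 3), include_top=False, weights=None)

# Expose the three skip activations by layer *name* (stable across versions).
skip_names = ("conv2_block3_out", "conv3_block4_out", "conv4_block6_out")
skips = [backbone.get_layer(n).output for n in skip_names]
extractor = tf.keras.Model(inputs=backbone.input, outputs=skips)

# Call the extractor on the output of our own layers, so every tensor
# belongs to one connected graph -- no "Graph disconnected" error.
inp = Input(shape=(224, 224, 4))
x = Conv2D(3, (3, 3), padding="same")(inp)
res2, res3, res4 = extractor(x)
model = tf.keras.Model(inputs=inp, outputs=[res2, res3, res4])
```

With a 224x224 input, the three outputs come out at strides 4, 8, and 16 (56x56x256, 28x28x512, 14x14x1024).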

Tags: python, tensorflow, keras

Solution


Try this:

import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, Concatenate,
                                     ZeroPadding2D, Softmax)

input_shape = (480, 854, 4)


# Target stream = Q
inputlayer_Q = Input(shape=input_shape, name="inputlayer_Q")
# Reference stream = M
inputlayer_M = Input(shape=input_shape, name="inputlayer_M")

convlayer_Q = Conv2D(filters=3, kernel_size=(3, 3), padding='same')(inputlayer_Q)
convlayer_M = Conv2D(filters=3, kernel_size=(3, 3), padding='same')(inputlayer_M)

model_Q = tf.keras.applications.resnet50.ResNet50(
    input_shape=(convlayer_Q.shape[1], convlayer_Q.shape[2], convlayer_Q.shape[3]),
    include_top=False, weights='imagenet'
)
model_Q._name = "resnet50_Q"

model_M = tf.keras.applications.resnet50.ResNet50(
    input_shape=(convlayer_M.shape[1], convlayer_M.shape[2], convlayer_M.shape[3]),
    include_top=False, weights='imagenet'
)
model_M._name = "resnet50_M"

# Rename every layer in both backbones so the two ResNet50s can coexist
# in one graph without duplicate layer names.
for model in [model_Q, model_M]:
  for layer in model.layers:
    layer._name = f"{model.name}_{layer.name}"
    

encoder_Q = tf.keras.Model(inputs=model_Q.inputs, outputs=model_Q.output, name="encoder_Q")
encoder_M = tf.keras.Model(inputs=model_M.inputs, outputs=model_M.output, name="encoder_M")


concatenate = Concatenate(axis=0, name="Concatenate")([encoder_Q.output, encoder_M.output])
# GlobalConvBlock is the user's own implementation of the paper's global
# convolution block (not shown here).
global_layer = GlobalConvBlock(concatenate)

# Skip connections taken from inside encoder_Q; the ZeroPadding2D calls
# compensate for the odd spatial dimensions produced by the 854-pixel width.
res2_skip = encoder_Q.layers[38].output
res2_skip = ZeroPadding2D(padding=(0, 1), data_format=None)(res2_skip)
res3_skip = encoder_Q.layers[80].output
res3_skip = ZeroPadding2D(padding=((0, 0), (0, 1)), data_format=None)(res3_skip)
res4_skip = encoder_Q.layers[142].output


# refineblock is the user's implementation of the paper's refinement module
# (not shown here).
ref1_16 = refineblock(res4_skip, global_layer, "ref1_16")
ref1_8 = refineblock(res3_skip, ref1_16, "ref1_8")
ref1_4 = refineblock(res2_skip, ref1_8, "ref1_4")

outconv = Conv2D(filters=2, kernel_size=(3, 3))(ref1_4)
outconv1 = ZeroPadding2D(padding=((1, 1), (0, 0)), data_format=None)(outconv)
output = Softmax()(outconv1)

# The model inputs are the encoders' own input layers (flattened into one
# list), and the model name must not contain spaces.
main_model = tf.keras.Model(inputs=encoder_Q.inputs + encoder_M.inputs,
                            outputs=output, name="main_model")
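The layer renaming above is the crux of this answer: two stock ResNet50s share internal layer names (e.g. `conv1_conv`), and Keras refuses to build a functional model containing duplicate layer names. A small sketch demonstrating the clash and the fix; the tiny input and `weights=None` are only to keep the example cheap:

```python
import tensorflow as tf

def make_backbone():
    # Two instances of this share internal layer names such as "conv1_conv".
    return tf.keras.applications.resnet50.ResNet50(
        include_top=False, weights=None, input_shape=(64, 64, 3))

m1, m2 = make_backbone(), make_backbone()

try:
    # Fails: 'The name "conv1_conv" is used 2 times in the model.'
    tf.keras.Model(inputs=[m1.input, m2.input],
                   outputs=[m1.output, m2.output])
    clash = False
except ValueError:
    clash = True

# Renaming every layer (as in the answer above) resolves the clash.
for prefix, model in (("q", m1), ("m", m2)):
    model._name = f"resnet50_{prefix}"
    for layer in model.layers:
        layer._name = f"{prefix}_{layer.name}"

merged = tf.keras.Model(inputs=[m1.input, m2.input],
                        outputs=[m1.output, m2.output])
```

Note that `_name` is a private attribute, so this trick may break in future Keras versions; it is, however, the common workaround for duplicating pretrained backbones in one graph.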
