首页 > 解决方案 > 用于堆叠两个 CNN 的自适应模块设计

问题描述

我正在尝试使用自适应模块堆叠两个不同的CNN来桥接它们,但我很难正确确定自适应模块的层超参数。

更准确地说,我想训练适配模块来桥接两个卷积层:

  1. 具有输出形状的 A 层:(29,29,256)
  2. 具有输入形状的 B 层:(8,8,384)

因此,在 A 层之后,我依次添加了适配模块,为此我选择了:

最后,我尝试将 B 层添加到模型中,但我从 tensorflow中得到以下错误:

InvalidArgumentError: Dimensions must be equal, but are 384 and 288 for '{{node batch_normalization_159/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=0.001, exponential_avg_factor=1, is_training=false](Placeholder, batch_normalization_159/scale, batch_normalization_159/ReadVariableOp, batch_normalization_159/FusedBatchNormV3/ReadVariableOp, batch_normalization_159/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [?,8,8,384], [288], [288], [288], [288].

有一个最小的可重现示例:

from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.mobilenet import MobileNet
from keras.layers import Conv2D, MaxPool2D
from keras.models import Sequential

mobile_model = MobileNet(weights='imagenet')
server_model = InceptionResNetV2(weights='imagenet')

hybrid = Sequential()

for i, layer in enumerate(mobile_model.layers):
  if i <= 36:
    layer.trainable = False
    hybrid.add(layer)

hybrid.add(Conv2D(384, kernel_size=(3,3), padding='same'))
hybrid.add(MaxPool2D(pool_size=(2,2), strides=(4,4), padding='same'))

for i, layer in enumerate(server_model.layers):
  if i >= 610:
    layer.trainable = False
    hybrid.add(layer)

标签: pythontensorflowkerasneural-networkconv-neural-network

解决方案


顺序模型只支持层像链表一样排列的模型——每一层只接受一层的输出,每一层的输出只馈送到单层。您的两个基本模型具有残差块,这打破了上述假设,并将模型架构转变为有向无环图 (DAG)。

要做你想做的事,你需要使用功能 API。使用功能 API,您可以显式控制中间激活,即 KerasTensors。

对于第一个模型,您可以跳过额外的工作,只需像这样从现有图的子集创建一个新模型

sub_mobile = keras.models.Model(mobile_model.inputs, mobile_model.layers[36].output)

连接第二个模型的某些层要困难得多。切掉 keras 模型的结尾很容易 - 由于需要 tf.keras.Input 占位符,因此切开开头要困难得多。要成功地做到这一点,您需要编写一个模型遍历算法,该算法遍历各个层,跟踪输出 KerasTensor,然后使用新输入调用每个层以创建新的输出 KerasTensor。

您可以通过简单地为 InceptionResNet 找到一些源代码并通过 Python 添加层而不是内省现有模型来避免所有这些工作。这是一个可能符合要求的。

https://github.com/yuyang-huang/keras-inception-resnet-v2/blob/master/inception_resnet_v2.py


推荐阅读