首页 > 解决方案 > 用不同的通道替换预训练模型的输入层?

问题描述

我想重新使用 的预训练权重MobiletNetv2,但图像有12 个通道。我知道这需要增加重量,但这没关系,因为无论如何我都想重新训练。我找不到让它工作的方法。

import tensorflow as tf

class CNN(tf.keras.Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.input_layer = tf.keras.layers.InputLayer(input_shape=(None, 224, 224, 12))
        self.base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                                      include_top=False,
                                                      weights='imagenet')
        _ = self.base._layers.pop(0)
        self.flat1 = tf.keras.layers.Flatten()
        self.dens3 = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        x = self.input_layer(x)
        x = self.base(x)
        x = self.flat1(x)
        x = self.dens3(x)
        return x

model = CNN()
model.build(input_shape=(None, 224, 224, 12))

ValueError: Input 0 is in compatible with layer mobilenetv2_1.00_224: expected shape=(None, 224, 224, 3), found shape=(None, 224, 224, 12)

我尝试像其他答案一样弹出第一层。

标签: pythontensorflowkeraskeras-layer

解决方案


可以加载两种模型,一种具有 12 通道的输入形状,另一种具有正常的 12 通道。然后,只需将 3 通道模型的权重加载到 12 通道模型中,从第 2 层或第 3 层开始。

这是执行重量转移的地方:

for i in range(3, len(self.base.layers)):
            self.base.layers[i].set_weights(base_weights.layers[i].get_weights())

这是整个事情:

import tensorflow as tf

h, w, c = 224, 224, 3


class CNNModel(tf.keras.Model):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.base = tf.keras.applications.MobileNetV2(input_shape=(h, w, 12),
                                                      include_top=False,
                                                      weights=None)
        base_weights = tf.keras.applications.MobileNetV2(input_shape=(h, w, c),
                                                         include_top=False,
                                                         weights='imagenet')

        for i in range(3, len(self.base.layers)):
            self.base.layers[i].set_weights(base_weights.layers[i].get_weights())

        del base_weights
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.drop1 = tf.keras.layers.Dropout(0.25)
        self.out = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x, training=None, **kwargs):
        x = self.base(x)
        x = self.pool(x)
        x = self.drop1(x)
        x = self.out(x)
        return x


model = CNNModel()

model.build(input_shape=(None, h, w, 12))

推荐阅读