Modifying individual layers of a TensorFlow model

Problem description

I am trying to build a model using transfer learning from inception-resnet-V2 with the imagenet weights. This is the part of my code that builds the model:

import tensorflow as tf

input_img_shape = (512, 512, 3)
# inception_resnet_v2 preprocessor:
preprocessor = tf.keras.applications.inception_resnet_v2.preprocess_input
# base inception_resnet_v2 model (no classification head, global average pooling on top)
base_model = tf.keras.applications.InceptionResNetV2(
    weights='imagenet', include_top=False,
    input_shape=input_img_shape, pooling='avg')

If I check the summary

base_model.summary()

I get the following model parameters (the initial layers are omitted here):

.
.
.
conv_7b (Conv2D)                (None, 14, 14, 1536) 3194880     block8_10[0][0]                  
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608        conv_7b[0][0]                    
__________________________________________________________________________________________________
conv_7b_ac (Activation)         (None, 14, 14, 1536) 0           conv_7b_bn[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536)         0           conv_7b_ac[0][0]                 
==================================================================================================
Total params: 54,336,736
Trainable params: 54,276,192
Non-trainable params: 60,544

I want to make the batch normalization layers of base_model non-trainable. I use the following code:

for layer in base_model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

and I get the following for my model:

base_model.summary()


.
.
.
.
conv_7b (Conv2D)                (None, 14, 14, 1536) 3194880     block8_10[0][0]                  
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608        conv_7b[0][0]                    
__________________________________________________________________________________________________
conv_7b_ac (Activation)         (None, 14, 14, 1536) 0           conv_7b_bn[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536)         0           conv_7b_ac[0][0]                 
==================================================================================================
Total params: 54,336,736
Trainable params: 54,245,920
Non-trainable params: 90,816

I can see that the number of non-trainable parameters increased by about 30K (from 60,544 to 90,816). This is correct.
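The roughly 30K jump can be checked by hand from the numbers in the two summaries. The sketch below assumes that the batch normalization layers in this network are built without a scale (gamma) term, which is consistent with conv_7b_bn holding 4608 = 3 x 1536 parameters: one trainable beta plus the two moving statistics per channel. Under that assumption, freezing the layers moves exactly one parameter per channel into the non-trainable count.

```python
# Parameter counts copied from the two summaries above.
non_trainable_before = 60_544
non_trainable_after = 90_816
trainable_before = 54_276_192
trainable_after = 54_245_920

# Assumption: each BatchNormalization layer has no gamma, so it stores
# 3 values per channel: beta, moving_mean and moving_variance.
conv_7b_bn_params = 4_608
conv_7b_channels = 1_536
params_per_channel = conv_7b_bn_params // conv_7b_channels

# moving_mean and moving_variance are never trainable: 2 values per channel,
# so the original non-trainable count gives the total number of BN channels.
bn_channels = non_trainable_before // 2

# Freezing the layers adds beta (1 value per channel) to the non-trainable count.
expected_non_trainable_after = non_trainable_before + bn_channels
expected_trainable_after = trainable_before - bn_channels

print(params_per_channel)                                   # 3
print(bn_channels)                                          # 30272
print(expected_non_trainable_after == non_trainable_after)  # True
print(expected_trainable_after == trainable_after)          # True
```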

Now I want to add another layer to this model, as follows:


output_layer = tf.keras.layers.Dense(1, activation='sigmoid') 

# put them together
i = tf.keras.layers.Input([None, None, input_img_shape[2]], dtype = tf.uint8)
x = tf.cast(i, tf.float32)
x = preprocessor(x) 
x = base_model(x)
x = output_layer(x)
model = tf.keras.Model(inputs=[i], outputs=[x])

Model summary:

model.summary()

Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
tf_op_layer_Cast (TensorFlow [(None, None, None, 3)]   0         
_________________________________________________________________
tf_op_layer_RealDiv (TensorF [(None, None, None, 3)]   0         
_________________________________________________________________
tf_op_layer_Sub (TensorFlowO [(None, 512, 512, 3)]     0         
_________________________________________________________________
inception_resnet_v2 (Functio (None, 1536)              54336736  
_________________________________________________________________
dense (Dense)                (None, 1)                 1537      
=================================================================
Total params: 54,338,273
Trainable params: 54,247,457
Non-trainable params: 90,816

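So far this works fine. Now I need to be able to iterate over the layers of this model and set trainable = True or False on the tf.keras.layers.BatchNormalization layers (for my use case). I also need to save this model, reload it, and do the same thing to the reloaded model.

Modifying the base_model variable may propagate the change to the model variable, but I cannot do that after the model has been saved and reloaded.

So I need a way to iterate over all layers of model and set trainable = True or False on the BatchNormalization layers only. I do not want to make all layers trainable. The code I used for base_model does not work on model, because my initial architecture shows up as just a single layer (inception_resnet_v2) in the model summary.

How can I iterate over the model's layers and modify individual layers?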

Tags: tensorflow, keras

Solution


You can use a recursive function that checks whether each layer is itself a tf.keras.Model and, if so, descends into it:

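def set_batch_norm_trainable(model, trainable=False):
    # iterate over the model passed in, not the global base_model
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = trainable
        # a nested model (e.g. inception_resnet_v2) appears as a single layer: recurse into it
        if isinstance(layer, tf.keras.models.Model):
            set_batch_norm_trainable(layer, trainable)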
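The same recursion can be written as a generator that walks every nested layer, which keeps the traversal separate from the change being applied. This is only a sketch: iter_layers and set_trainable_by_type are hypothetical helper names, not part of Keras, and the code assumes that any layer exposing a .layers attribute is a nested model. It deliberately does not import TensorFlow; the layer class to match is passed in as an argument.

```python
def iter_layers(model):
    """Yield every layer of `model`, descending into nested models."""
    for layer in model.layers:
        yield layer
        # Assumption: a layer that has a `.layers` attribute is a nested model.
        if hasattr(layer, "layers"):
            yield from iter_layers(layer)


def set_trainable_by_type(model, layer_type, trainable):
    """Set `trainable` on every layer of `layer_type`; return how many matched."""
    matched = 0
    for layer in iter_layers(model):
        if isinstance(layer, layer_type):
            layer.trainable = trainable
            matched += 1
    return matched
```

With the model from the question, the call would look like set_trainable_by_type(model, tf.keras.layers.BatchNormalization, False), and the same call can be applied to a model that was saved and loaded again. Keras only picks up a changed trainable flag when the model is compiled, so compile the model again after toggling the layers and before training.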
