tensorflow - 修改 TensorFlow 模型的各个层
问题描述
我正在尝试使用来自 inception-resnet-V2 的迁移学习和 imagenet权重来构建模型。这是我构建模型的代码的一部分
input_img_shape = (512, 512, 3)
# inception_resnet_v2 preprocessor:
preprocessor = tf.keras.applications.inception_resnet_v2.preprocess_input
# base inception_resnet_v2 model
base_model = tf.keras.applications.InceptionResNetV2(weights='imagenet', include_top=False,
input_shape=input_img_shape, pooling='avg')
如果我检查摘要
base_model.summary()
我把它作为我的模型参数(我在这里省略了初始层):
.
.
.
conv_7b (Conv2D) (None, 14, 14, 1536) 3194880 block8_10[0][0]
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608 conv_7b[0][0]
__________________________________________________________________________________________________
conv_7b_ac (Activation) (None, 14, 14, 1536) 0 conv_7b_bn[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536) 0 conv_7b_ac[0][0]
==================================================================================================
Total params: 54,336,736
Trainable params: 54,276,192
Non-trainable params: 60,544
我想将 base_model 批量标准化层设为不可训练。我使用以下代码
for layer in base_model.layers:
if isinstance(layer, tf.keras.layers.BatchNormalization):
layer.trainable = False
我得到以下作为我的模型
base_model.summary()
.
.
.
.
conv_7b (Conv2D) (None, 14, 14, 1536) 3194880 block8_10[0][0]
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608 conv_7b[0][0]
__________________________________________________________________________________________________
conv_7b_ac (Activation) (None, 14, 14, 1536) 0 conv_7b_bn[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536) 0 conv_7b_ac[0][0]
==================================================================================================
Total params: 54,336,736
Trainable params: 54,245,920
Non-trainable params: 90,816
我可以看到不可训练的参数增加了约 30K。这是对的。
现在我想向这个模型添加另一层,如下所示:
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')
# put them together
i = tf.keras.layers.Input([None, None, input_img_shape[2]], dtype = tf.uint8)
x = tf.cast(i, tf.float32)
x = preprocessor(x)
x = base_model(x)
x = output_layer(x)
model = tf.keras.Model(inputs=[i], outputs=[x])
型号总结:
model.summary()
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, None, None, 3)] 0
_________________________________________________________________
tf_op_layer_Cast (TensorFlow [(None, None, None, 3)] 0
_________________________________________________________________
tf_op_layer_RealDiv (TensorF [(None, None, None, 3)] 0
_________________________________________________________________
tf_op_layer_Sub (TensorFlowO [(None, 512, 512, 3)] 0
_________________________________________________________________
inception_resnet_v2 (Functio (None, 1536) 54336736
_________________________________________________________________
dense (Dense) (None, 1) 1537
=================================================================
Total params: 54,338,273
Trainable params: 54,247,457
Non-trainable params: 90,816
到目前为止,这工作正常。现在我需要能够遍历这个模型的各个层并设置tf.keras.layers.BatchNormalization.trainable = True 或 False(对于我的用例)。我需要保存这个模型并重新加载并执行相同的操作 - tf.keras.layers.BatchNormalization.trainable = True 或 False。
修改base_model变量可能会反映对模型变量的更改,但在保存并重新加载模型后我无法执行此操作。
所以我需要一种方法来遍历模型的所有层并设置BatchNormalization.trainable = True 或 False。我不想将所有层都设置为可训练的。使用我用于 base_model 的相同代码不起作用。我不能再将此代码与模型一起使用,因为我的初始架构仅显示为模型摘要中的一个层。
如何遍历模型层并修改各个层?
解决方案
您可以使用递归函数,检查图层的类型是否为tf.keras.Model
:
def set_batch_norm_trainable(model, trainable=False):
for layer in base_model.layers:
if isinstance(layer, tf.keras.layers.BatchNormalization):
layer.trainable = trainable
if isinstance(layer, tf.keras.models.Model):
set_batch_norm_trainable(layer, trainable)
推荐阅读
- c# - 无法访问 ASP.NET Core C# 项目中的嵌入式资源?
- typescript - TypeGraphql:如何在 type-graphql 中触发订阅?
- python - 如何将在我的 html 表单中输入的值带到 python 文件中,然后将其存储在数据库中?
- matlab - 标记坐标和 ginput - Matlab
- windows-community-toolkit - WinUI 3 和 Windows 社区工具包 - InitializeComponent 错误
- react-native - 如何在 react-native 中从 webview 获取 url?
- r - 在 dplyr::mutate_at() 函数中即时定义自定义函数
- ldap - 如何在 ldap 中创建“动态组”?
- python - 将新窗口(pysimplegui)中的列表框值更改为用户可以选择一个选项并暂停主执行
- nlp - NLP 中的“黄金”是什么意思?