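python - How to set the trainable parameter in a concatenated Keras model
Problem description
The original code is too unwieldy, so I will try to explain the problem with a simplified example.
First, import the libraries we need: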
import tensorflow as tf
from keras.applications.resnet50 import ResNet50
from keras.models import Model
from keras.layers import Dense, Input
Then load a pretrained model and print its summary.
model = ResNet50(weights='imagenet')
model.summary()
Here is the output of summary():
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 55, 55, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 55, 55, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 55, 55, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 55, 55, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 55, 55, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 55, 55, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 55, 55, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 55, 55, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 55, 55, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
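(I truncated the output of summary() to save space.) At this point, every layer is trainable. As an example, I set the trainable attribute of one layer to False as follows.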
model.get_layer('bn5c_branch2c').trainable = False
Now every layer except bn5c_branch2c is still trainable.
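One way to check which layers are currently frozen is to loop over model.layers and read each layer's trainable flag. The sketch below uses a tiny stand-in model instead of ResNet50 so that it runs without downloading weights; the names toy, dense_a, bn_a and dense_b are made up for illustration and do not come from the question.

```python
from keras.layers import BatchNormalization, Dense, Input
from keras.models import Model

# Small stand-in model (hypothetical layer names), used here in place of ResNet50.
inputs = Input(shape=(8,), name="toy_input")
x = Dense(4, name="dense_a")(inputs)
x = BatchNormalization(name="bn_a")(x)
outputs = Dense(2, name="dense_b")(x)
toy = Model(inputs, outputs, name="toy")

# Freeze a single layer, as the question does with bn5c_branch2c.
toy.get_layer("bn_a").trainable = False

# Collect the names of all layers whose trainable flag is False.
frozen = [layer.name for layer in toy.layers if not layer.trainable]
print(frozen)
```

The same loop should work on the ResNet50 model from the question, where the list would be expected to include bn5c_branch2c.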
Next, create a new model from this original model, but make it a concatenated model.
in1 = Input(shape=(224, 224, 3), name="in1")
in2 = Input(shape=(224, 224, 3), name="in2")
out1 = model(in1)
out2 = model(in2)
new_model = Model(inputs=[in1, in2], outputs=[out1, out2])
And print the summary again:
new_model.summary()
And the output:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
in1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
in2 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
resnet50 (Model) (None, 1000) 25636712 in1[0][0]
in2[0][0]
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
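At this point I have lost the ability to see which layers are trainable and which are not, because all the layers of the original ResNet50 model now show up as a single layer. If I run the following code, it gives me True: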
new_model.get_layer('resnet50').trainable # Returns True
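Question 1) I did set the trainable attribute of layer bn5c_branch2c to False in model. Can I assume that the trainable value of bn5c_branch2c is still False inside new_model?
Question 2) If the answer to the question above is yes (meaning the trainable value of bn5c_branch2c is still False in new_model), and I later save the architecture and weights of new_model and load them again to continue training it, can I trust that the trainable value of bn5c_branch2c will remain False?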
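Solution
Note: you can access a model's layers through the .layers[idx] attribute, where idx is the zero-based index of the layer in the model. Alternatively, if you have given the layers names, you can access them with the .get_layer(layer_name) method.
A1) Yes, and you can confirm it as follows: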
print(new_model.layers[2].get_layer('bn5c_branch2c').trainable) # output: False
You can also confirm this by looking at the number of non-trainable parameters in the model summary.
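Because summary() collapses the wrapped model into a single row, a small recursive helper can make the check in A1 more general by listing frozen layers inside nested models. This is only a sketch: frozen_layer_names is a hypothetical helper, and the toy model with its layer names (toy, dense_a, bn_a, dense_b) is a stand-in for ResNet50 so the example runs without pretrained weights.

```python
from keras.layers import BatchNormalization, Dense, Input
from keras.models import Model


def frozen_layer_names(model, prefix=""):
    """Return names of layers with trainable == False, descending into nested models."""
    names = []
    for layer in model.layers:
        full_name = prefix + layer.name
        if not layer.trainable:
            names.append(full_name)
        if isinstance(layer, Model):  # a model used as a layer: look inside it
            names.extend(frozen_layer_names(layer, prefix=full_name + "/"))
    return names


# Stand-in for the pretrained model (hypothetical names).
inp = Input(shape=(8,), name="toy_input")
x = Dense(4, name="dense_a")(inp)
x = BatchNormalization(name="bn_a")(x)
out = Dense(2, name="dense_b")(x)
toy = Model(inp, out, name="toy")
toy.get_layer("bn_a").trainable = False

# Wrap the same model instance behind two inputs, as in the question.
in1 = Input(shape=(8,), name="in1")
in2 = Input(shape=(8,), name="in2")
wrapper = Model(inputs=[in1, in2], outputs=[toy(in1), toy(in2)])

# The wrapper holds a reference to the original model object, not a copy,
# so flags set on its inner layers are still visible through the wrapper.
wrapped = wrapper.get_layer("toy")
print(wrapped is toy)
print(frozen_layer_names(wrapper))
```

Looking the nested model up by name avoids relying on its position in .layers, which depends on how many inputs the wrapper has.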
A2) Yes, and you can confirm it as follows:
from keras.models import load_model

# save it
new_model.save('my_new_model.hd5')
# load it again
new_model = load_model('my_new_model.hd5')
print(new_model.layers[2].get_layer('bn5c_branch2c').trainable) # output: False
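A likely reason the flag survives is that each layer's trainable value is part of the layer's configuration, which is what gets serialized along with the architecture. The sketch below illustrates this on a small stand-in model (toy, dense_a, bn_a and dense_b are hypothetical names); it rebuilds only the architecture from the config, without weights and without writing a file, and it checks a single model rather than the two-input wrapper.

```python
from keras.layers import BatchNormalization, Dense, Input
from keras.models import Model

# Stand-in model with one frozen layer (hypothetical names).
inp = Input(shape=(8,), name="toy_input")
x = Dense(4, name="dense_a")(inp)
x = BatchNormalization(name="bn_a")(x)
out = Dense(2, name="dense_b")(x)
toy = Model(inp, out, name="toy")
toy.get_layer("bn_a").trainable = False

# The per-layer config records the trainable flag.
bn_config = toy.get_layer("bn_a").get_config()
print(bn_config["trainable"])

# Rebuild the architecture (fresh weights) from the model config
# and read the flag back from the rebuilt layer.
rebuilt = Model.from_config(toy.get_config())
print(rebuilt.get_layer("bn_a").trainable)
```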