Problem displaying the summary of an entire Keras CNN model in TensorFlow with Python

Problem description

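I have been trying to view the filters and feature maps of a trained image-classification system built on the "mobilenet v2" model with "imagenet" weights, but I'm running into problems. I'm fairly sure I know the cause; I just don't know how to fix it.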

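I originally built it by following a TensorFlow example (https://www.tensorflow.org/tutorials/images/transfer_learning), in which I made a classification model. I now want to view the filters and feature maps after training it on my own image set.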

I found several examples online of how to view the layers; the best one is https://www.kaggle.com/arpitjain007/guide-to-visualize-filters-and-feature-maps-in-cnn.
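Guides like this one typically rescale each kernel to the [0, 1] range before plotting it as an image. A minimal NumPy sketch of that step, assuming the Keras Conv2D kernel layout (kernel_h, kernel_w, in_channels, out_channels); `normalize_filters` is a hypothetical name, not taken from the guide:

```python
import numpy as np


def normalize_filters(kernel):
    """Min-max rescale a convolution kernel to [0, 1] for display.

    Assumes the Keras Conv2D layout (kernel_h, kernel_w, in_channels, out_channels).
    """
    kernel = np.asarray(kernel, dtype=np.float32)
    k_min, k_max = kernel.min(), kernel.max()
    if k_max == k_min:
        # Constant kernel: avoid dividing by zero
        return np.zeros_like(kernel)
    return (kernel - k_min) / (k_max - k_min)


# Random kernel shaped like MobileNetV2's Conv1 (3x3, 3 -> 32 channels)
rng = np.random.default_rng(0)
demo = normalize_filters(rng.normal(size=(3, 3, 3, 32)))
print(demo.shape, demo.min(), demo.max())
```

Each slice `demo[:, :, c, f]` can then be shown with an image-plotting call of your choice.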

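Unfortunately, when I try to view the filters and feature maps of my trained model, I can't find any convolutional layers. When I summarize my model: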

import tensorflow as tf

model_1 = tf.keras.models.load_model('saved_model/my_model')
model_1.summary()

This prints:

Model: "model1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
____________________________________
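A quick check on these numbers suggests the convolutional layers are all there, just hidden: the only rows with parameters are `mobilenetv2_1.00_160 (Functional)` and `dense`, and together they make up the entire total. The 2,257,984 on the nested row is also exactly the MobileNetV2 total printed by `base_model.summary()` further down. The arithmetic, using only figures from the summary:

```python
# Parameter counts copied from the model_1.summary() output above
nested_mobilenet = 2_257_984  # mobilenetv2_1.00_160 (Functional) row
dense = 1280 * 1 + 1          # Dense(1) on 1280 pooled features: weights + bias
total = nested_mobilenet + dense

non_trainable = 396_544
trainable = total - non_trainable

print(dense, total, trainable)
```

This yields 1281, 2,259,265 and 1,862,721, matching the `dense` row and the Total and Trainable lines of the summary.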

When I try to list the layers, I get nothing:

for layer in model_1.layers:
    # Skip any layer whose name does not contain 'conv'
    if 'conv' not in layer.name:
        continue
    filters, bias = layer.get_weights()
    print(layer.name, filters.shape)

This should print all of the conv layers, which the MobileNet model has plenty of, but it returns nothing.
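The loop most likely finds nothing because `model_1.layers` only lists the top-level rows shown in the summary; the MobileNetV2 convolutions live inside the single `mobilenetv2_1.00_160` layer. A minimal sketch of a recursive walk, assuming only that container layers expose a `.layers` attribute, as Keras `Model` and `Sequential` objects do (`iter_layers` and the stand-in objects are hypothetical):

```python
from types import SimpleNamespace as NS


def iter_layers(model):
    """Yield every layer in `model`, descending into nested models.

    Duck-typed: any layer that has a `.layers` attribute is treated as a
    container and walked recursively.
    """
    for layer in model.layers:
        yield layer
        if hasattr(layer, "layers"):
            yield from iter_layers(layer)


# Hypothetical stand-ins shaped like model_1; names come from the summaries.
mobilenet = NS(name="mobilenetv2_1.00_160",
               layers=[NS(name="Conv1"), NS(name="bn_Conv1"),
                       NS(name="expanded_conv_depthwise")])
fake_model_1 = NS(name="model1",
                  layers=[NS(name="input_2"), mobilenet, NS(name="dense")])

names = [layer.name for layer in iter_layers(fake_model_1)]
print(names)
```

With the real model, `iter_layers(model_1)` would take the place of `model_1.layers` in the loop. Even then, the lowercase `'conv'` name test would skip `Conv1`, `block_1_expand` and `block_1_depthwise` (all convolutions in the MobileNetV2 summary) while letting through layers such as `expanded_conv_depthwise_BN`, so checking the layer type is more reliable than checking the name.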

I can view the MobileNet layers on their own by calling something like this:

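import tensorflow as tf

IMG_SHAPE = (160, 160, 3)  # matches the (None, 160, 160, 3) input in the summaries

base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
# train_dataset is the tf.data pipeline built in the transfer-learning tutorial
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
base_model.trainable = False
base_model.summary()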

Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0]
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]

Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984

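(Note that I've left many layers out of the MobileNet summary because there are so many of them and I don't think they are relevant to this question.)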
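The Param # column above also points to a second problem with the loop in the question. For the conv rows shown, the counts match kernel-only formulas (k·k·c_in·c_out for Conv2D, k·k·c_in for DepthwiseConv2D), so these convolutions appear to carry no bias, and `layer.get_weights()` would return a single array for them; `filters, bias = layer.get_weights()` would then raise a `ValueError`. The 3×3 and 1×1 kernel sizes below are an assumption based on MobileNetV2's standard layout, which the counts are consistent with; the helper names are hypothetical.

```python
def conv2d_params(k, c_in, c_out, bias):
    """Parameters of a k x k Conv2D layer mapping c_in -> c_out channels."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def depthwise_params(k, c_in, bias):
    """Parameters of a k x k DepthwiseConv2D layer with depth_multiplier=1."""
    return k * k * c_in + (c_in if bias else 0)


# Param # column for the conv rows in the MobileNetV2 summary
observed = {
    "Conv1": 864,
    "expanded_conv_depthwise": 288,
    "expanded_conv_project": 512,
    "block_1_expand": 1536,
    "block_1_depthwise": 864,
}

# Kernel-only counts; channel numbers are read off the Output Shape column
kernel_only = {
    "Conv1": conv2d_params(3, 3, 32, bias=False),
    "expanded_conv_depthwise": depthwise_params(3, 32, bias=False),
    "expanded_conv_project": conv2d_params(1, 32, 16, bias=False),
    "block_1_expand": conv2d_params(1, 16, 96, bias=False),
    "block_1_depthwise": depthwise_params(3, 96, bias=False),
}

print(observed == kernel_only)
print(conv2d_params(3, 3, 32, bias=True))  # what Conv1 would have with a bias
```

Taking only `layer.get_weights()[0]` (the kernel) therefore seems safer than unpacking two values. BatchNormalization layers whose names contain `conv`, such as `expanded_conv_depthwise_BN`, return four arrays and would trip the same unpacking.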

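I think the problem lies in how the model is defined and summarized. I need to be able to define the model so that one summary shows all of the layers in "model_1" together with the MobileNet layers from the base_model summary.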
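For a single combined listing, newer TensorFlow/Keras releases let `Model.summary()` expand nested models inline with `expand_nested=True`; older releases do not accept that argument, and on those `model_1.get_layer('mobilenetv2_1.00_160').summary()` prints the inner layers as a separate table. A sketch under those assumptions, using a tiny stand-in model (hypothetical names) so it runs without downloading weights; `full_summary` is a hypothetical helper:

```python
import tensorflow as tf


def full_summary(model):
    """Return model.summary() text with nested models expanded inline."""
    lines = []
    # print_fn receives the summary text; any extra args are ignored.
    model.summary(expand_nested=True,
                  print_fn=lambda line, *args, **kwargs: lines.append(line))
    return "\n".join(lines)


# Tiny stand-in for model_1: a nested backbone followed by a small head.
inner_in = tf.keras.Input(shape=(8, 8, 3))
inner_out = tf.keras.layers.Conv2D(4, 3, name="Conv1")(inner_in)
inner = tf.keras.Model(inner_in, inner_out, name="tiny_backbone")

outer_in = tf.keras.Input(shape=(8, 8, 3))
x = tf.keras.layers.GlobalAveragePooling2D()(inner(outer_in))
outer = tf.keras.Model(outer_in, tf.keras.layers.Dense(1)(x), name="outer")

print(full_summary(outer))
```

On the real model this would amount to calling `model_1.summary(expand_nested=True)` directly, provided the installed version supports the argument.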

I was hoping it would be as simple as calling 'model2 = model_1 + base_model', but that doesn't work.
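Keras models cannot be combined with `+`, but judging from the summary there may be no need to combine anything: the `mobilenetv2_1.00_160 (Functional)` row is the whole MobileNetV2, so it can be fetched from `model_1` with `get_layer` and explored directly. The sketch below uses hypothetical helpers (`conv_kernels`, `feature_map_model`) and a tiny stand-in model instead of the real one. It selects conv layers by type rather than by name, and builds the feature-map model from the nested model's own inputs; using `model_1.input` instead typically fails with a "graph disconnected" error. Because the `tf.math.truediv` / `tf.math.subtract` preprocessing sits outside the nested model, images fed to such a feature-map model should be preprocessed the same way (the tutorial uses `tf.keras.applications.mobilenet_v2.preprocess_input`).

```python
import numpy as np
import tensorflow as tf

CONV_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.DepthwiseConv2D)


def conv_kernels(outer_model, nested_name):
    """Return {layer_name: kernel array} for each conv layer in a nested model."""
    nested = outer_model.get_layer(nested_name)
    kernels = {}
    for layer in nested.layers:
        if isinstance(layer, CONV_TYPES):
            # The kernel is the first weight; a bias, if present, comes after it.
            kernels[layer.name] = layer.get_weights()[0]
    return kernels


def feature_map_model(outer_model, nested_name, layer_names):
    """Build a model from the nested model's own input to the named inner layers."""
    nested = outer_model.get_layer(nested_name)
    outputs = [nested.get_layer(name).output for name in layer_names]
    return tf.keras.Model(inputs=nested.inputs, outputs=outputs)


# Tiny stand-in for model_1 (hypothetical names) so the sketch runs offline.
inner_in = tf.keras.Input(shape=(8, 8, 3))
h = tf.keras.layers.Conv2D(4, 3, use_bias=False, name="Conv1")(inner_in)
h = tf.keras.layers.DepthwiseConv2D(3, use_bias=False, name="block_1_depthwise")(h)
inner = tf.keras.Model(inner_in, h, name="tiny_backbone")

outer_in = tf.keras.Input(shape=(8, 8, 3))
x = tf.keras.layers.GlobalAveragePooling2D()(inner(outer_in))
outer = tf.keras.Model(outer_in, tf.keras.layers.Dense(1)(x))

kernels = conv_kernels(outer, "tiny_backbone")
for name, k in kernels.items():
    print(name, k.shape)

fm = feature_map_model(outer, "tiny_backbone", ["Conv1", "block_1_depthwise"])
maps = fm.predict([np.zeros((1, 8, 8, 3), dtype="float32")], verbose=0)
for name, m in zip(["Conv1", "block_1_depthwise"], maps):
    print(name, m.shape)
```

With the real model, the calls would presumably become `conv_kernels(model_1, "mobilenetv2_1.00_160")` and `feature_map_model(model_1, "mobilenetv2_1.00_160", [...])` with layer names taken from the MobileNetV2 summary.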

I hope this makes sense and that someone can help!

Tags: python tensorflow keras image-classification

Solution

