Keras model.summary does not display my model correctly..?

Problem Description

I want to view a summary of my model with keras.Model.summary, but it does not work as expected. My code is as follows:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def trythis(self, x):
        # BatchNormalization is created here, inside the forward pass,
        # rather than being assigned as an attribute in __init__
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        x = self.conv1(x)
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.build((None, 32, 32, 3))
model.summary()

I expected to see a BatchNorm layer, but the summary looks like this:

Model: "my_model_30"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_31 (Conv2D)           multiple                  896       
_________________________________________________________________
flatten_30 (Flatten)         multiple                  0         
_________________________________________________________________
dense_60 (Dense)             multiple                  3686528   
_________________________________________________________________
dense_61 (Dense)             multiple                  1290      
=================================================================
Total params: 3,688,714
Trainable params: 3,688,714
Non-trainable params: 0

It does not include the BatchNorm layer from the trythis method.

How can I solve this?

Thank you for reading.

Tags: python, tensorflow, keras

Solution


Shape inference for a subclassed model is not automatic the way it is in the functional API, and a layer that is only created inside a method such as trythis (rather than assigned as an attribute in __init__) is not tracked by the model, so it does not appear in the subclassed model's summary. So I added a model method to the subclassed model that defines a functional model, as shown below. Note that there are several ways to do this; I am showing one of them. Please see my answer to a similar question here for more details.

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def trythis(self, x):
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        x = self.conv1(x)
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

    def model(self):
        # Wrap call() in a functional model so that every layer, including
        # the BatchNormalization created inside trythis, is traced and shows
        # up in summary() with concrete output shapes.
        x = tf.keras.layers.Input(shape=(32, 32, 3))
        return Model(inputs=[x], outputs=self.call(x))

model = MyModel()
model_functional = model.model()
# model.build((None, 32, 32, 3))  # not needed for the functional summary
model_functional.summary()

The summary now looks like this:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 32)        128       
_________________________________________________________________
flatten_4 (Flatten)          (None, 28800)             0         
_________________________________________________________________
dense_8 (Dense)              (None, 128)               3686528   
_________________________________________________________________
dense_9 (Dense)              (None, 10)                1290      
=================================================================
Total params: 3,688,842
Trainable params: 3,688,778
Non-trainable params: 64
_________________________________________________________________
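As a side note, since there are several ways to do this: if the goal is only to have the BatchNorm layer tracked and counted by the subclassed model itself, another common variant is to create the BatchNormalization layer in __init__ so that Keras registers it as an attribute. A minimal sketch, assuming the same imports as above; the class name MyModelTracked is illustrative:

class MyModelTracked(Model):
    def __init__(self):
        super(MyModelTracked, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        # Created here (not inside a helper called from call), so the model
        # tracks the layer and summary() lists it and counts its parameters.
        self.bn = BatchNormalization()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def call(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model2 = MyModelTracked()
model2.build((None, 32, 32, 3))
model2.summary()  # batch_normalization is listed, though output shapes still read "multiple"

Creating the layer in __init__ also avoids building a fresh BatchNormalization (with fresh weights) on every forward pass, which is what happens when it is instantiated inside trythis.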
