python - Keras.model.summary 没有正确显示我的模型..?
问题描述
我想通过keras.model.summary查看我的模型的摘要,但是效果不好。我的代码如下:
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32,3,activation = 'relu')
self.flatten = Faltten()
self.d1 = Dense(128, activation = 'relu')
self.d2 = Dense(10, activation = 'relu')
def trythis(self,x):
a = BatchNormalization()
b = a(x)
return b
def call(self, x):
x = self.conv1(x)
x = trythis(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
model = MyModel()
model.build((None, 32,32,3))
model.summary()
我期望 BatchNorm 层,但总结如下:
Model: "my_model_30"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_31 (Conv2D) multiple 896
_________________________________________________________________
flatten_30 (Flatten) multiple 0
_________________________________________________________________
dense_60 (Dense) multiple 3686528
_________________________________________________________________
dense_61 (Dense) multiple 1290
=================================================================
Total params: 3,688,714
Trainable params: 3,688,714
Non-trainable params: 0
它不包含“trythis”方法中的 BatchNorm 层。
我怎么解决这个问题?
感谢您的阅读。
解决方案
子类模型的形状推断不像功能 API 中那样自动。所以我在子类模型中添加了一个模型调用,并定义了一个功能模型,如下所示。请注意,有几种方法可以做,我展示的是一种方法。请查看我在此处回答的类似问题的更多详细信息
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32,3,activation = 'relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation = 'relu')
self.d2 = Dense(10, activation = 'relu')
def trythis(self,x):
a = BatchNormalization()
b = a(x)
return b
def call(self, x):
x = self.conv1(x)
x = MyModel.trythis(self,x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
def model(self):
x = tf.keras.layers.Input(shape=(32, 32, 3))
return Model(inputs=[x], outputs=self.call(x))
model = MyModel()
model_functional = model.model()
#model.build((None, 32,32,3))
model_functional.summary()
总结如下
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 30, 30, 32) 896
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 32) 128
_________________________________________________________________
flatten_4 (Flatten) (None, 28800) 0
_________________________________________________________________
dense_8 (Dense) (None, 128) 3686528
_________________________________________________________________
dense_9 (Dense) (None, 10) 1290
=================================================================
Total params: 3,688,842
Trainable params: 3,688,778
Non-trainable params: 64
_________________________________________________________________
推荐阅读
- python - 解析网页时无法提取单行
- java - 如何用一维数组中的值填充二维数组?
- c - 嵌入式 C 中的 __forceinline
- vue.js - Vue:有条件地允许基于另一个道具的值的道具类型
- qt - QNetworkAccessManager 是否支持 HTTPS 代理?
- python - 我正在尝试将来自爬虫的信息放入 json 文件中,但是当我添加新对象时出现 json 多个顶级错误,我该如何解决这个问题?
- react-native - 在本机反应中不导航到特定屏幕
- azure - 添加的客户端 IP 不会保留在 Azure 服务器防火墙设置中
- appcode - 在 AppCode 中与 Swift 包管理器同步失败:尝试写入只读数据库
- angular - 如何在材质Angular中将一个对话弹出窗口置于另一个之上