python - Keras 创建 CNN 模型“添加的层必须是类层的实例”
问题描述
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import Dropout, Flatten, Input, Dense
def create_model():
def add_conv_block(model, num_filters):
model.add(Conv2D(num_filters, 3, activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(Conv2D(num_filters, 3, activation='relu', padding='valid'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.2))
return model
model = tf.keras.models.Sequential()
model.add(Input(shape=(32, 32, 3)))
model = add_conv_block(model, 32)
model = add_conv_block(model, 64)
model = add_conv_block(model, 128)
model.add(Flatten())
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
model = create_model()
model.summary()
解决方案
解决方案是使用InputLayer
而不是Input
. InputLayer
旨在与Sequential
模型一起使用。您也可以InputLayer
完全省略并input_shape
在顺序模型的第一层中指定。
Input
旨在与 TensorFlow Keras 功能 API 一起使用,而不是与顺序 API 一起使用。
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import Dropout, Flatten, InputLayer, Dense
def create_model():
def add_conv_block(model, num_filters):
model.add(Conv2D(num_filters, 3, activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(Conv2D(num_filters, 3, activation='relu', padding='valid'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.2))
return model
model = tf.keras.models.Sequential()
model.add(InputLayer((32, 32, 3)))
model = add_conv_block(model, 32)
model = add_conv_block(model, 64)
model = add_conv_block(model, 128)
model.add(Flatten())
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
model = create_model()
model.summary()
推荐阅读
- vb.net - 显示新表单时的视觉基本问题
- azure - 如何通过terraform将本地文件复制到azure vm?
- c++ - 3d 三角形卡住了 c++ opengl glfw
- flutter - 如何为给定类型的 Flutter 小部件强制执行某些参数值?
- javascript - 如何将参数注入 TestCafé 测试?
- authorization - 微服务中的授权 - 用户可以访问的行列表
- sql - SQL Server:为临时表创建聚集索引
- python - 计算器的功率增加字符串
- php - 如何为数据库查询返回给我的每一行制作饼图?
- django - 从 django web 框架中的引导按钮调用 python 函数