python - 用于识别序列数据的 CNN 模型的配置 - CNN 顶部的架构 - 并行层
问题描述
我正在尝试配置一个网络以识别车牌等顺序数据的字符。现在我想使用深度自动车牌识别系统中表 3 中提到的架构(链接:http://www.ee.iisc.ac.in/people/faculty/soma.biswas/Papers/jain_icgvip2016_alpr。 .pdf )。
作者提出的架构是这样的:
第一层很常见,但我绊倒的是架构的顶部(红框内的部分)。他们提到了 11 个并行层,我真的不确定如何在 Python 中得到它。我编写了这个架构,但它似乎不适合我。
model = Sequential()
model.add(Conv2D(64, kernel_size=(5, 5), input_shape = (32, 96, 3), activation = "relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=(3, 3), activation = "relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(256, kernel_size=(3, 3), activation = "relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1024, activation = "relu"))
model.add(Dense(11*37, activation="Softmax"))
model.add(keras.layers.Reshape((11, 37)))
有人可以帮忙吗?我如何必须对顶部进行编码才能获得与作者一样的平等架构?
解决方案
下面的代码可以构建图像中描述的架构。
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, Flatten, MaxPooling2D, Dense, Input, Reshape, Concatenate, Dropout
def create_model(input_shape = (32, 96, 1)):
input_img = Input(shape=input_shape)
'''
Add the ST Layer here.
'''
model = Conv2D(64, kernel_size=(5, 5), input_shape = input_shape, activation = "relu")(input_img)
model = MaxPooling2D(pool_size=(2, 2))(model)
model = Dropout(0.25)(model)
model = Conv2D(128, kernel_size=(3, 3), input_shape = input_shape, activation = "relu")(model)
model = MaxPooling2D(pool_size=(2, 2))(model)
model = Dropout(0.25)(model)
model = Conv2D(256, kernel_size=(3, 3), input_shape = input_shape, activation = "relu")(model)
model = MaxPooling2D(pool_size=(2, 2))(model)
model = Dropout(0.25)(model)
model = Flatten()(model)
backbone = Dense(1024, activation="relu")(model)
branches = []
for i in range(11):
branches.append(backbone)
branches[i] = Dense(37, activation = "softmax", name="branch_"+str(i))(branches[i])
output = Concatenate(axis=1)(branches)
output = Reshape((11, 37))(output)
model = Model(input_img, output)
return model
推荐阅读
- javascript - 如何使用 Spring JPA 和 JQUERY ajax 构建和发送实体?
- flutter - Flutter2 升级:为什么 build_runner 包不是空安全的
- java - Hibernate SQLQuery.list() 获取大量记录时非常慢
- java - 当在 android 10 应用程序崩溃但 android 11 中创建目录时它运行良好
- linux - 如何从 Linux 设备树中导出 GPIO 输出引脚
- c++ - 十进制转二进制,再转成字符串矩阵
- python - SonarQube Jenkins 以自动方式设置?
- go - 如何通过 Golang 中的反射更新地图值
- python - 解决通道/异步测试用例上的 django 测试 RuntimeError
- go - golang 如何在运行时运行智能类型断言?