python - 第二个 MaxPool2D 之后的宽度、高度比下采样因子小 1
问题描述
我对 ML 和使用 Keras 和 TF 创建 NN 还很陌生,所以我正在按照这里的教程创建一个验证码 ocr 模型,并为我自己的验证码图像稍微调整代码。
我正在使用以下函数来构建网络:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# image dimensions
image_width = 120
image_height = 40
# batch size for training and validation
batch_size = 16
def build_model():
# inputs
input_image = layers.Input(
shape=(image_width, image_height, 1), name='image', dtype='float32'
)
labels = layers.Input(shape = (None,), name='label', dtype='float32')
# first conv block
x = layers.Conv2D(
32,
(3, 3),
activation='relu',
kernel_initializer='he_normal',
padding='same',
name='Conv1'
)(input_image)
x = layers.MaxPool2D((2, 2), name='Pool1')(x)
print()
# second conv block
x = layers.Conv2D(
64,
(3, 3),
activation='relu',
kernel_initializer='he_normal',
name='Conv2'
)(x)
x = layers.MaxPool2D((2, 2), name='Pool2')(x)
# reshape
new_shape = ((image_width // 4), (image_height // 4) * 64)
x = layers.Reshape(target_shape=new_shape, name='Reshape')(x)
x = layers.Dense(64, activation='relu', name='Dense1')(x)
x = layers.Dropout(0.2)(x)
# recurrent block
x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)
# output
x = layers.Dense(11, activation='softmax', name='Dense2')(x)
output = CTCLayer(name='CTCLoss')(labels, x)
# define model
model = keras.models.Model(
inputs=[input_image, labels], outputs=output, name='CaptchaOCRv1'
)
model.compile(optimizer=keras.optimizers.Adam())
return model
给定输入图像尺寸 (120, 40, 1) 和两个 (2,2) 池化层,我希望重塑层输入的形状为 (30, 10, 64)。但是我收到以下错误:
ValueError: total size of new array must be unchanged, input_shape = [29, 9, 64], output_shape = [30, 640]
我不明白为什么宽度和高度比因子 4 下采样小 1。谁能指出我的错误?
解决方案
推荐阅读
- c# - C#:Visual Studio 2019:错误没有被捕获
- machine-learning - 在 LSTM 中,应该在训练集和测试集分割之前还是之后进行归一化?
- python - (Python) 为变量分配特定字符在多维数组中出现的次数
- sql - 替换 SQL 中的值
- python - TensorFlow GPU 在 Python 2.7 的 multiprocessing.Process 调用分叉的新进程中不可用
- python - Msys2 升级中断 python2-pyqt5
- laravel - 无法通过 Laravel Backpack 中的方法覆盖订单
- node.js - 使用强大/快速时文件不写入磁盘
- javascript - 无法使选择过滤器 JavaScript?
- javascript - 无法使用 Google 的 node.js 客户端库生成 JWT 客户端。运行代码时出现 Typeerror