tensorflow - 如何修复 MNIST 手写数据集的 lstm 和 cnn 代码
问题描述
我正在为 MNIST 手写数据集编写一个代码中的 LSTM+CNN,如何解决维度问题?
我为 MNIST 手写数据集分别编码了 LSTM 和 CNN,但合并有问题
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.layers import Dense, Dropout, LSTM
################### Loading dataset ##########################
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape((60000, 28, 28, 1))
x_test = x_test.reshape((10000, 28, 28, 1))
################### Normalizing dataset ######################
x_train, x_test = x_train / 255.0, x_test / 255.0
################### Building a model #########################
ConvNN_model = models.Sequential()
ConvNN_model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
ConvNN_model.add(layers.MaxPooling2D((2, 2)))
ConvNN_model.add(layers.Conv2D(64, (3, 3), activation='relu'))
ConvNN_model.add(LSTM(128, activation='relu'))
ConvNN_model.add(Dropout(0.2))
ConvNN_model.add(layers.Dense(64, activation='relu'))
ConvNN_model.add(layers.Dropout(0.25))
ConvNN_model.add(layers.Dense(10, activation='softmax'))
################### Compiling a model ########################
ConvNN_model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
################### Fitting a model ##########################
ConvNN_model.fit(x = x_train,
y = y_train,
epochs = 1,
validation_data = (x_test, y_test))
我遇到了这个问题:
ValueError Traceback (last last call last) in () 23 ConvNN_model.add(layers.Conv2D(64, (3, 3), activation='relu')) 24 ---> 25 ConvNN_model.add(LSTM(128, activation ='relu')) 26 ConvNN_model.add(Dropout(0.2)) 27
ValueError: Input 0 of layer lstm_7 is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [None, 11,11, 64]
你能帮我解决这个问题吗?先感谢您。
解决方案
这里的诀窍是使用TimeDistributed
. 由于您在行和列方向都有序列,因此需要首先对其中一个进行编码。下面我们首先使用包裹在 TimeDistributed 中的 LSTM 对行进行编码,然后使用 LSTM 对列进行编码。
from tensorflow.keras import models, layers
from tensorflow.keras.layers import LSTM, Dropout, Dense, TimeDistributed
################### Building a model #########################
ConvNN_model = models.Sequential()
ConvNN_model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
ConvNN_model.add(layers.MaxPooling2D((2, 2)))
ConvNN_model.add(layers.Conv2D(64, (3, 3), activation='relu'))
# encode rows of matrix
ConvNN_model.add(TimeDistributed(LSTM(128, activation='relu')))
ConvNN_model.add(Dropout(0.2))
# encode columns
ConvNN_model.add(LSTM(128, activation='relu'))
ConvNN_model.add(layers.Dense(64, activation='relu'))
ConvNN_model.add(layers.Dropout(0.25))
ConvNN_model.add(layers.Dense(10, activation='softmax'))
################### Compiling a model ########################
ConvNN_model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
################### Fitting a model ##########################
ConvNN_model.fit(x = x_train,
y = y_train,
epochs = 1,
validation_data = (x_test, y_test))
推荐阅读
- c - 共享内存中的简单检查返回 SIGSEGV 错误 008b
- php - 表单未收集用户输入,即使我输入了有效的表和列,查询也无法正常工作
- java - 如何使用 Cucumber 和 Rally 集成自动更新 Rally 测试用例?
- bash - 如何在我的 git 环境中修复 pre-commit-msg?
- optimization - 如何使用严格的不等式比较浮点变量?
- javascript - 我的 .on Swipe 在中断之前只能工作一次并且必须刷新所有内容 - JQuery Mobile / Javascript
- xml - 创建maven项目时如何修复“找不到前缀'原型'的插件”错误
- java - 如何减少工具栏上的文本和汉堡图标之间的差距
- c# - 路径中的非法字符
- bash - tr 命令在 bash 脚本中失败