python - tf.keras Conv2D 层的输入大小不合适
问题描述
我正在按照教程中概述的步骤进行操作
我正在尝试在 Google Colaboratory 笔记本内的单元格中运行教程中的以下代码:
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) =
tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
model = tf.keras.Sequential()
# Must define the input shape in the first layer of the neural network
model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
# Take a look at the model summary
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit(x_train,
y_train,
batch_size=64,
epochs=10)
# Evaluate the model on test set
score = model.evaluate(x_test, y_test, verbose=0)
# Print test accuracy
print('\n', 'Test accuracy:', score[1])
当我运行 cell 时,出现以下错误:
Error when checking input: expected conv2d_5_input to have 4 dimensions, but got array with shape (60000, 28, 28)
我觉得我遗漏了一些对卷积层的使用至关重要的东西,但看起来这似乎应该有效。我在 SO 上发现了一些类似的问题,人们建议操纵“input_shape”参数。我尝试将其更改为 (60000, 28, 28) 并且还添加了值为 1 的附加维度,但到目前为止没有任何效果。谁能指出我在这里可能遗漏了什么?
解决方案
看起来您跳过了教程中的重塑部分:
# Reshape input data from (28, 28) to (28, 28, 1)
w, h = 28, 28
x_train = x_train.reshape(x_train.shape[0], w, h, 1)
x_valid = x_valid.reshape(x_valid.shape[0], w, h, 1)
x_test = x_test.reshape(x_test.shape[0], w, h, 1)
这里的想法是您的样本是 28x28x1(一种颜色,28x28 像素),第一维 - 样本的数量(在您的情况下为 60000)。
推荐阅读
- spark-streaming - 使用 HAIL 解析 .bgen 文件,而不在单个节点上加载数据
- python - django-heroku 没有安装
- chart.js - 如何更改 Charts.js 中水平条的起点
- java - 无法使用具有文本逗号 csv 的值拆分字符串
- excel - Excel 文件缺少对 Microsoft Windows Common Controls-2 的引用
- python - pydrive.auth.RefreshError:访问令牌刷新失败:invalid_grant:令牌已过期或撤销
- c - 使用 CGO 将 Go 嵌套的结构数组转换为 C?
- javascript - NodeList 元素上的 forEach() 函数
- kubernetes - 如何在 kubernetes v1.19.0 上将“--token-auth-file=SOMEFILE”标志设置为 apiserver
- kotlin - Kotlin 中可空字符串的重载解析歧义错误