python - 如何使用 tensorflow 数据集训练神经网络?
问题描述
我正在尝试在 emnist 数据集上训练神经网络,但是当我尝试展平图像时,它会引发以下错误:
WARNING:tensorflow:Model 是用形状 (None, 28, 28) 构造的输入 Tensor("flatten_input:0", shape=(None, 28, 28), dtype=float32),但它是在不兼容的输入上调用的形状(无、1、28、28)。
我无法弄清楚似乎是什么问题,并尝试更改我的预处理,从我的 model.fit 和我的 ds.map 中删除批量大小。
这是完整的代码:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
def preprocess(dict):
image = dict['image']
image = tf.transpose(image)
label = dict['label']
return image, label
train_data, validation_data = tfds.load('emnist/letters', split = ['train', 'test'])
train_data_gen = train_data.map(preprocess).shuffle(1000).batch(32)
validation_data_gen = validation_data.map(preprocess).batch(32)
print(train_data_gen)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape = (28, 28)),
tf.keras.layers.Dense(128, activation = 'relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation = 'softmax')
])
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy',
metrics = ['accuracy'])
early_stopping = keras.callbacks.EarlyStopping(monitor = 'val_accuracy', patience = 10)
history = model.fit(train_data_gen, epochs = 50, batch_size = 32, validation_data = validation_data_gen, callbacks = [early_stopping], verbose = 1)
model.save('emnistmodel.h5')
解决方案
所以这里实际上发生了一些事情,所以让我们一次解决它们。
输入形状
因此,要解决您的直接问题,您会收到不兼容的形状错误,因为输入的形状与预期的形状不匹配。
在这一行
tf.keras.layers.Flatten(input_shape=(28, 28)),
中,我们告诉模型期望形状为 (28, 28) 的输入,但这并不准确。我们的输入实际上具有形状 (28, 28, 1),因为我们正在拍摄具有1 个通道的 28x28 像素图像(而不是具有 3 个通道 r、g 和 b 的彩色图像)。所以为了解决这个直接的问题,我们只需更新模型以使用输入的形状。IEtf.keras.layers.Flatten(input_shape=(28, 28, 1)),
输出节点数
正如 Rishabh 在他的回答中所建议的那样,EMNIST 数据集有超过 10 个平衡类。但是,在您的情况下,您似乎正在使用具有 26 个平衡类的 EMNIST Letters。因此,您的神经网络应该相应地具有 27 个输出节点(因为类标签从 1.. 26 开始,而我们的输出节点对应于 0.. 26)才能对给定数据进行分类。当然,给它额外的输出节点也可以让它运行,但是这些会给我们额外的权重来训练,这不是必需的,这会增加我们模型所需的训练时间。简而言之,你的最后一层应该是
tf.keras.layers.Dense(27, activation='softmax')
预处理 TensorFlow 数据集
阅读您的 preprocess() 函数,我相信您正在尝试将训练和验证数据集转换为(图像,标签)的元组。TensorFlow 没有创建我们自己的函数,而是通过参数as_supervised方便地为我们实现了这一点。
此外,我看到您尝试实现的一些额外预处理,例如对数据进行批处理和洗牌。同样,TensorFlow 为我们实现了batch_size和shuffle_files(参见常用参数)!所以加载数据集看起来像
train_data, validation_data = tfds.load('emnist/letters', split=['train', 'test'], shuffle_files=True, batch_size=32, as_supervised=True)
一些附加说明
此外,作为建议,请考虑
batch_size
从 model.fit() 中排除。在两个不同的地方定义相同的东西会导致错误和意外行为。此外,当使用 TensorFlow 数据集时,没有必要,因为它们已经生成了批次。
总体而言,您更新的程序应如下所示
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow import keras
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
train_data, validation_data = tfds.load('emnist/letters',
split=['train', 'test'],
shuffle_files=True,
batch_size=32,
as_supervised=True)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(27, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
early_stopping = keras.callbacks.EarlyStopping(
monitor='val_accuracy', patience=10)
history = model.fit(train_data,
epochs=50,
validation_data=validation_data,
callbacks=[early_stopping],
verbose=1)
model.save('emnistmodel.h5')
希望这可以帮助!
推荐阅读
- reactjs - 如何使用 Formik 管理嵌套的复杂对象
- c# - 如何在 ASP.NET MVC 项目中下载 PDF
- entity-framework - AsNoTracking 与 HasNoKey
- android - SAF - 将文件从私有应用程序文件夹复制到授权的 SAF 文件夹时出现 NullPointerException
- c - 使用堆栈而不是 C 中的堆进行动态内存分配
- javascript - 将制造提前期添加到网站 Javascript 到 HTML
- python - 如何从文本匹配组中排除某些字符?
- node.js - Linux上的SharpJS管道内存泄漏
- javascript - 如何在anime.js中链接动画?
- java - 找到接口 org.apache.poi.util.POILogger,但预期类错误