tensorflow - Tensorflow Keras shape mismatch
Problem description
While implementing the standard MNIST digit recognizer that many tutorials use to introduce neural networks, I ran into the error
ValueError: Shape mismatch: The shape of labels (received (1,)) should equal the shape of logits except for the last dimension (received (28, 10)).
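I want to use from_tensor_slices to process the data, because I want to apply the code to another problem where the data comes from a CSV file. Anyway, here is the code; the error is raised on the model.fit(...) line: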
import tensorflow as tf
train_dataset, test_dataset = tf.keras.datasets.mnist.load_data()
train_images, train_labels = train_dataset
train_images = train_images/255.0
train_dataset_tensor = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
num_of_validation_data = 10000
validation_data = train_dataset_tensor.take(num_of_validation_data)
train_data = train_dataset_tensor.skip(num_of_validation_data)
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation='sigmoid'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
model.fit(train_data, batch_size=50, epochs=5)
performance = model.evaluate(validation_data)
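I don't understand where the (28, 10) shape of the logits comes from. I thought I was flattening the images, essentially turning each 2D image into a 1D vector. How can I prevent the error?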
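The (28, 10) shape can be reproduced without Keras. The sketch below is a NumPy mock of the model's forward pass, not Keras itself; the function names and random weights are hypothetical. It follows the documented behavior of Keras Flatten (the first axis is left alone as the batch axis) and creates the first weight matrix from the flattened feature size, the way a Dense layer is built on its first call.

```python
import numpy as np


def flatten_like_keras(x):
    # Keras Flatten keeps axis 0 (treated as the batch axis) and flattens the rest.
    return x.reshape(x.shape[0], -1)


def forward(x, rng):
    # Hypothetical mock of Flatten -> Dense(100, sigmoid) -> Dense(10, softmax).
    h = flatten_like_keras(x)
    # Like a Dense layer built on first call, size the weights from the input features.
    w1 = rng.standard_normal((h.shape[1], 100))
    h = 1.0 / (1.0 + np.exp(-(h @ w1)))  # sigmoid
    w2 = rng.standard_normal((100, 10))
    z = h @ w2
    # Numerically stable softmax over the last axis.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
single = rng.random((28, 28))      # one unbatched image, as the unbatched dataset yields
batch = rng.random((50, 28, 28))   # a batch of 50 images

print(forward(single, rng).shape)  # (28, 10): the 28 rows are mistaken for 28 samples
print(forward(batch, rng).shape)   # (50, 10): one prediction per image
```

With an unbatched image, the 28 rows play the role of 28 samples, which is exactly the logits shape reported in the error.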
Solution
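The error occurs because train_data is never batched. A dataset created with from_tensor_slices yields one example at a time, so each image reaches the model with shape (28, 28), and the batch_size argument of model.fit does not batch a tf.data.Dataset (a dataset is expected to yield batches itself). Keras therefore treats the first axis of each image as the batch axis: Flatten keeps that axis and flattens only the rest, which leaves a (28, 28) image unchanged, so the final Dense layer outputs (28, 10) while there is only a single label. Call .batch(...) on the dataset before passing it to fit. Also, because the last layer already applies softmax, the loss should use from_logits=False. For example: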
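import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1); not required for this dense model
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# .batch(32) is the key step: every element now has a leading batch axis
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation='sigmoid'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # The last layer already applies softmax, so its outputs are probabilities, not logits
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy']
)
# No batch_size argument: the dataset already yields batches
model.fit(train_ds, epochs=5)
model.evaluate(test_ds)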
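The same fix can be applied to the question's own take/skip split. This is a minimal sketch that assumes synthetic random data in place of MNIST so it runs without a download; the array names, sample counts, and shuffle buffer size are placeholders. It splits first, then batches each split, drops the batch_size argument from fit, and uses from_logits=False because the model ends in softmax.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for MNIST (assumption: 200 random images with random labels).
rng = np.random.default_rng(0)
images = rng.random((200, 28, 28)).astype("float32")
labels = rng.integers(0, 10, size=200).astype("int64")

full = tf.data.Dataset.from_tensor_slices((images, labels))
num_of_validation_data = 50

# Split first, then batch each split so every element gets a leading batch axis.
validation_data = full.take(num_of_validation_data).batch(50)
train_data = full.skip(num_of_validation_data).shuffle(150).batch(50)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation="sigmoid"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # The model already ends in softmax, so its outputs are probabilities.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

# No batch_size argument: the dataset already yields batches of 50.
history = model.fit(train_data, epochs=1, verbose=0)
model.evaluate(validation_data, verbose=0)

# Inspect one batch: images are (batch, 28, 28) and predictions are (batch, 10).
batch_images, batch_labels = next(iter(train_data))
print(batch_images.shape, model(batch_images).shape)
```

Batching after take/skip keeps the split counted in individual examples rather than in batches, which matches how num_of_validation_data is used in the question.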