python - TensorFlow 数据集的意外维度
问题描述
我正在尝试在 MNIST 数据集上使用 InceptionV3 进行迁移学习。
计划是读取 MNIST 数据集,调整图像大小,然后使用它们进行训练,如下所示:
import numpy as np
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow.compat.v2 as tf
import tensorflow.compat.v1 as tfv1
from tensorflow.python.keras.applications import InceptionV3
tfv1.enable_v2_behavior()
print(tf.version.VERSION)
img_size = 299
def preprocess_tf_image(image, label):
image = tf.image.grayscale_to_rgb(image)
image = tf.image.resize(image, [img_size, img_size])
return image, label
#Acquire MNIST data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#Convert data to [0,1] range
x_train, x_test = x_train / 255.0, x_test / 255.0
#Add extra dimension to images so that they can be converted to RGB
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape (x_test.shape[0], 28, 28, 1)
x_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
x_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
#Convert images to RGB space and resize
x_train = x_train.map(preprocess_tf_image)
x_test = x_test.map(preprocess_tf_image)
img_shape = (img_size, img_size, 3)
#Get trained model, but leave off the head
base_model = InceptionV3(input_shape = img_shape, weights='imagenet', include_top=False)
base_model.trainable = False
#Make a model with a new head
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
#Compile model
model.compile(
optimizer='adam', #tf.keras.optimizers.RMSprop(lr=BASE_LEARNING_RATE),
loss='binary_crossentropy',
metrics=['accuracy']
)
model.fit(x_train, epochs=5)
model.evaluate(x_test)
但是,当我运行它时,事情会model.fit()
因错误而停止:
ValueError:检查输入时出错:预期 inception_v3_input 有 4 个维度,但得到了形状为 (299、299、3) 的数组
这是怎么回事?
解决方案
应用map
到数据集后,响应没有关于批量大小的信息,您必须调用batch
函数来添加它:
x_train = x_train.batch(batch_size = BATCH_SIZE) # adds batch size dimension to train dataset
x_test = x_test.batch(batch_size = BATCH_SIZE) # idem for test.
之后,我可以使用 Google 的 Colab 完全训练和评估模型,您可以在此处查看。
推荐阅读
- ios - Swift - Mopub 原生广告已成功加载,但未显示在视图中
- macos - 无法使用安装了 macports 的 MPICH mpirun 执行 MPI 程序
- jquery - 检索引导复选框的单击状态
- android - 防止网页链接打开本机应用程序,留在网页浏览器中
- c# - EF Code First-从多到多表返回记录
- ios - Stripe SDK integration using swift and flutter
- java - How can I have fonts inside project and use from project folder in NetBeans IDE?
- node.js - 使用 node-rdkafka 的 Kerberos SASL 身份验证
- html - 一种当设备尺寸较小时,部分将不会被包含在内
- python - API 网关 - 自定义授权方不工作