首页 > 解决方案 > Tensorflow 中 Flatten() 层中输入和输出形状的问题

问题描述

我正在尝试为图像分类构建一个基本的卷积神经网络。我有一个属于 4 个类的图像数据集。ImageDataGenerator我使用and创建了 TensorFlow 数据集flow_from_directory

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2
)

val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_data = train_datagen.flow_from_directory(
    dataset_source, 
    color_mode='rgb',
    batch_size=batch_size, 
    class_mode='categorical',
    shuffle=True,
    subset='training'
) 

val_data = val_datagen.flow_from_directory(
    dataset_source, 
    color_mode='rgb',
    batch_size=batch_size, 
    class_mode='categorical',
    shuffle=False,
    subset='validation'
)

我测试的 CNN 架构示例很简单,如下所示:

image_shape = (299, 299, 3)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=image_shape))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(4))

model.compile(loss = "categorical_crossentropy",
             optimizer = "adam",
             metrics = ["accuracy"])

每当我尝试使用 开始训练model.fit(train_data, epochs=100, validation_data = val_data)时,都会收到以下与Flatten()层相关的错误:

InvalidArgumentError:  Input to reshape is a tensor with 7372800 values, but the requested shape requires a multiple of 645248
[[node sequential_1/flatten_1/Reshape (defined at <ipython-input-9-d01204576b1d>:1) ]] [Op:__inference_train_function_1823]

我使用各种版本的 CNN 对其进行了检查,即使是最简单的包含单个卷积层,但每次都会出现错误,只是值不同。这可能很明显,但是尽管尝试了很多次,我还是无法解决它,所以我会非常感谢任何关于如何处理我的问题的指导。

标签: pythontensorflowconv-neural-network

解决方案


在不指定 a 的情况下使用flow_from_directory调用target_size只是自找麻烦,因为即使是单个大小错误的输入也会扰乱您的模型拟合过程。简单的解决方法是为了target_size=(299, 299)安全起见添加到调用中(如评论中所示)。

此外,请考虑在调用image_shape = (299, 299, 3)上方添加并指定从那里提取它。这将使修改更容易,您可能想要尝试各种输入大小缩放。flow_from_directorytarget_size = image_shape[:-1]image_shape


推荐阅读