首页 > 解决方案 > ds_train 的形状为 (2, 224, 224, 3) 而不是 (None, 224, 224, 3)

问题描述

我使用以下代码创建了自己的自定义数据集(有 2 个类):

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import matplotlib.pyplot as plt
ds_train = tf.keras.preprocessing.image_dataset_from_directory(
        'C:/Users/mydir/Source_Images/',
        labels = 'inferred', # from subfolders in alphabetical order
        label_mode = "int",
        class_names = ["CVS", "No_CVS"],
        color_mode = 'rgb',
        batch_size = 2,
        image_size = (224, 224),
        shuffle = True, # randomized order of images
        seed = 123, #set the seed if  train, valid images are the same when you run again   
        validation_split = 0.1,
        subset = "training"
        )

df_train 结果:

<BatchDataset shapes: ((None, 224, 224, 3), (None,)), types: (tf.float32, tf.int32)>

现在,我想通过查看 9 张图像来可视化我的数据:

for i, (image, label) in enumerate(ds_train.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("uint8"))
    plt.axis("off")

但是,我收到以下错误:

第 61 行,在

plt.imshow(image.numpy().astype("uint8"))
TypeError: Invalid shape (2, 224, 224, 3) for image data

我正在寻找一种方法来解决这个问题,并能够用 matplotlib 绘制我的图像。

编辑:

更重要的是,似乎在训练模型时也无法使用数据的数据,因为我收到了这个错误:

   ValueError: Input 0 is incompatible with layer EfficientNet: expected shape=(None, 224, 224, 3), found shape=(2, None, 224, 224, 3)

运行我在此处找到的 Keras 示例代码后(我在其中创建了 ds_train ,image_dataset_from_directory而不是tdsf.load()函数)。

所以我认为我创建 ds_train 的方式出了点问题。任何决议都非常受欢迎。

标签: pythonimagetensorflowmatplotlibkeras

解决方案


当您这样做时,似乎您要离开了batch_size

plt.imshow(image.numpy().astype("uint8"))

使用您的原始代码,您仍然无法看到 9 张图像,因为您的batch_size. 我认为如果你这样做会很好:

不应抛出任何错误,例如TypeError: Invalid shape...

plt.imshow(image[i].numpy().astype("uint8"))

此外,您可以执行以下操作来查看 batch_size:

for img_batch_size, labels_batch_size in train_df:
  print(img_batch_size.shape)
  print(labels_batch_size.shape)

对于您的情况img_batch_size.shape,应打印 (2,224,224,3) 此元组对应于图像张量的位置。

对于input_shape问题,您需要添加您的模型,以便我们可以看到有什么问题input_shape


推荐阅读