首页 > 解决方案 > 如何将从 image_dataset_from_directory 获得的数据集拆分为数据和标签?

问题描述

我正在尝试使用 Python 在 TensorFlow 中构建 CNN。我已将图像加载到数据集中,如下所示:

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "train_data", shuffle=True, image_size=(578, 260),
    batch_size=BATCH_SIZE)

但是,如果我想在这个数据集上使用 train_test_split 或 fit_resample,我需要将它分成数据和标签。我是 TensorFlow 新手,不知道该怎么做。非常感谢任何帮助。

标签: pythontensorflowkerastensorflow-datasets

解决方案


您可以使用该subset参数将数据分隔为trainingvalidation

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)


train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  image_size=(256, 256),
  seed=1,
  batch_size=32)

val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=1,
  image_size=(256, 256),
  batch_size=32)

for x, y in train_ds.take(1):
  print('Image --> ', x.shape, 'Label --> ',  y.shape)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
Image -->  (32, 256, 256, 3) Label -->  (32,)

至于你的标签,根据文档

“推断”(从目录结构生成标签)、无(无标签)或与目录中找到的图像文件数量相同大小的整数标签列表/元组。标签应根据图像文件路径的字母数字顺序排序(通过 Python 中的 os.walk(directory) 获得)。

所以只需尝试迭代train_ds并查看它们是否存在。您还可以使用参数label_mode来引用您拥有的标签类型并class_names明确列出您的类。

如果您的类不平衡,您可以使用 的class_weights参数model.fit(*)。有关更多信息,请查看此帖子


推荐阅读