首页 > 解决方案 > 减少数据集的类

问题描述

假设我有这样初始化的 CIFAR-100(images) 数据集:

cifar100_builder = tfds.builder("cifar100")
cifar100_builder.download_and_prepare()
ds_train = cifar100_builder.as_dataset(split="train")
ds_test = cifar100_builder.as_dataset(split="test")

例如 ds_train 是一个类型的对象:

<DatasetV1Adapter shapes: {coarse_label: (), image: (32, 32, 3), label: ()}, types: {coarse_label: tf.int64, image: tf.uint8, label: tf.int64}> which is a `tf.data.dataset`

这个数据集包含 100 个类。假设我还有一个名为的列表our_index,它有 20 个不同的元素,每个元素代表一个类。我想做的是遍历 ds_train 数据集并只保留属于这 20 个类之一的元素。为此,我想我可以使用这个:[ https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter][1]

但我不确定如何。有什么想法吗?

标签: tensorflowmachine-learningkeras

解决方案


使用我在上面的评论中为您提供的链接的答案,我可以过滤数据集以仅包含标签 0、1 和 2,如下所示:

import tensorflow_datasets as tfds
import tensorflow as tf

def predicate(x, allowed_labels=tf.constant([0., 1., 2.])):
    label = x['label']
    isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
    reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
    return tf.greater(reduced, tf.constant(0.))

cifar100_builder = tfds.builder("cifar100")
cifar100_builder.download_and_prepare()
ds_train = cifar100_builder.as_dataset(split="train")
ds_test = cifar100_builder.as_dataset(split="test")

filtered_ds_train=ds_train.filter(predicate)
filtered_ds_test=ds_test.filter(predicate)

现在迭代并打印 filters_ds_train 的标签我们可以看到只选择了3 个标签。

for x in myclasses:
  print(x['label'])

您可以更改allowed_labels=tf.constant([0., 1., 2.])参数以包含其他类标签。目前它选择标签 0、1 和 2。


推荐阅读