tensorflow - 减少数据集的类
问题描述
假设我有这样初始化的 CIFAR-100(images) 数据集:
cifar100_builder = tfds.builder("cifar100")
cifar100_builder.download_and_prepare()
ds_train = cifar100_builder.as_dataset(split="train")
ds_test = cifar100_builder.as_dataset(split="test")
例如 ds_train 是一个类型的对象:
<DatasetV1Adapter shapes: {coarse_label: (), image: (32, 32, 3), label: ()}, types: {coarse_label: tf.int64, image: tf.uint8, label: tf.int64}> which is a `tf.data.dataset`
这个数据集包含 100 个类。假设我还有一个名为的列表our_index
,它有 20 个不同的元素,每个元素代表一个类。我想做的是遍历 ds_train 数据集并只保留属于这 20 个类之一的元素。为此,我想我可以使用这个:[ https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter][1]。
但我不确定如何。有什么想法吗?
解决方案
使用我在上面的评论中为您提供的链接的答案,我可以过滤数据集以仅包含标签 0、1 和 2,如下所示:
import tensorflow_datasets as tfds
import tensorflow as tf
def predicate(x, allowed_labels=tf.constant([0., 1., 2.])):
label = x['label']
isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
return tf.greater(reduced, tf.constant(0.))
cifar100_builder = tfds.builder("cifar100")
cifar100_builder.download_and_prepare()
ds_train = cifar100_builder.as_dataset(split="train")
ds_test = cifar100_builder.as_dataset(split="test")
filtered_ds_train=ds_train.filter(predicate)
filtered_ds_test=ds_test.filter(predicate)
现在迭代并打印 filters_ds_train 的标签,我们可以看到只选择了3 个标签。
for x in myclasses:
print(x['label'])
您可以更改allowed_labels=tf.constant([0., 1., 2.])参数以包含其他类标签。目前它选择标签 0、1 和 2。
推荐阅读
- java - 应用程序实例之间的 Java 同步
- javascript - Javascript 将大字符串拆分成带有引号和换行符的小部分
- javascript - 这是改变反应状态的可接受方式吗?
- python - 嵌套元组......将两个元组相乘并存储在第三个......怎么做......?
- excel - 如何将数据框附加到Excel工作表中
- python-3.x - Python Keras LSTM 'ValueError: `start_index+length='
- flutter - 键盘出现时将文本字段滚动到视图中 Flutter
- docker - Docker 客户端 API 版本问题
- c# - 从代码中突出显示 ListView 项 - WPF 数据绑定
- javascript - 在每个固定小时后 1 分钟刷新的脚本