python - 如何在 Tensorflow 中设置 ParallelMapDataset 数据类型中的图像数量?
问题描述
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
train_images = dataset['train']
test_images = dataset['test']
train_batches = (
train_images
.cache()
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.prefetch(buffer_size=tf.data.AUTOTUNE))
test_batches = test_images.batch(BATCH_SIZE)
现在我想将 test_images 的大小减少到 100 张图像。我期待一些代码,如:
test_images = test_images[100]
但这会产生错误:
'ParallelMapDataset' object is not subscriptable
解决方案
使用take()
方法,您可以从目标数据集中获取批次或项目。
如果数据集是批处理的:
test_images.take((100 // BATCH_SIZE) + 1)
当您对数据集进行批处理时,它将包含批次或组。
因此,假设您使用大小为 32 的数据批量处理,test_images.take(1)
将返回 32 个元素,即单个批次。test_images.take(2)
将返回 64 个元素等。
如果没有批处理:
test_images.take(100)
与批处理数据集不同,数据集将返回已传递给take()
方法的元素数量。