首页 > 解决方案 > 如何在 Tensorflow 中设置 ParallelMapDataset 数据类型中的图像数量?

问题描述

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

train_images = dataset['train']

test_images = dataset['test']

train_batches = ( 
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

现在我想将 test_images 的大小减少到 100 张图像。我期待一些代码,如:

test_images = test_images[100]

但这会产生错误:

'ParallelMapDataset' object is not subscriptable

标签: pythontensorflowkerasdata-structuresdeep-learning

解决方案


使用take()方法,您可以从目标数据集中获取批次或项目。

如果数据集是批处理的:

test_images.take((100 // BATCH_SIZE) + 1)

当您对数据集进行批处理时,它将包含批次或组。

因此,假设您使用大小为 32 的数据批量处理,test_images.take(1)将返回 32 个元素,即单个批次。test_images.take(2)将返回 64 个元素等。


如果没有批处理:

test_images.take(100)

与批处理数据集不同,数据集将返回已传递给take()方法的元素数量。


推荐阅读