首页 > 解决方案 > 用于动态提取补丁和展平数据集的 TF 管道

问题描述

我要在图像补丁上训练一个自动编码器。我的训练数据由加载到形状为 numpy 的数组中的单通道图像组成[10000, 256, 512, 1]。我知道如何从图像中提取补丁,但是批次选择图像是相当不直观的,因此每批次的点数取决于每个图像提取了多少补丁。如果每个图像提取 32 个补丁,我希望数据集表现得好像它是[320000, 256, 512, 1]这样的,以便一次从多个图像中提取洗牌和批次,但动态提取补丁,这样就不必保留在记忆中。

我见过的最接近的问题是加载 tensorflow 图像并创建补丁,但正如我所提到的,它不能提供我想要的。

PATCH_SIZE = 64

def extract_patches(imgs, patch_size=PATCH_SIZE, stride=PATCH_SIZE//2):
    # extract patches and reshape them into patch images
    n_channels = imgs.shape[-1]
    if len(imgs.shape) < 4:
        imgs = tf.expand_dims(imgs, axis=0)  
    return tf.reshape(tf.image.extract_patches(imgs,
                                               sizes=[1, patch_size, patch_size, n_channels],
                                               strides=[1, stride, stride, n_channels],
                                               rates=[1, 1, 1, 1],
                                               padding='VALID'),
                      (-1, patch_size, patch_size, n_channels))

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            )

创建一个返回具有形状的批次的数据集,(batch_size, 105, 64, 64, 1)而我想要一个具有形状和随机播放的等级 4 张量(batch_size, 64, 64, 1)来对补丁进行操作(而不是每个图像的补丁集合)。如果我放在.map管道的末端

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            )

这确实会使批次变平并返回一个等级为 4 的张量,但在这种情况下,每个批次都有 shape (840, 64, 64, 1)

标签: pythontensorflowtensorflow-datasets

解决方案


推荐阅读