首页 > 解决方案 > TensorFlow Dataset.flat_map() 导致批次不均

问题描述

从每个包含 10 个样本的文件中加载我的输入,并用每批 4 个样本对它们进行批处理,我得到大小不均匀的批次 4、4、2、4、4、2 等,而不是像我一样将连续文件中的样本组合在一起的批次在展平数据集后期望。

我正在使用 TensorFlow 1.8.0。为了从文件中获取数据到我的 Dataset 对象中,我遵循了这个答案。我的输入管道如下所示:

# Initialize dataset on files
dataset = tf.data.Dataset.list_files(input_files_list)

# Pre-process data in parallel
def preprocess_fn(input_file):
    # lots of logic here...
    return input1, input2, input3

map_fn = lambda input_file: tf.py_func(
    preprocess_fn, [input_file], [tf.float32, tf.float32, tf.float32])
dataset = dataset.map(map_func=map_fn, num_parallel_calls=4)

# Flatten from files to samples
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))

dataset = dataset.batch(batch_size=4)
dataset = dataset.prefetch(buffer_size=8)

但我看到的是样本实际上并没有在输入文件之间连接,因此批量大小是不均匀的。我认为这是因为flat_map()将每个元素(文件中的所有输入样本)映射到数据集 - 所以在flat_map()我的数据集实际上是数据集的数据集之后,每个嵌套数据集都是单独批处理的。

但这不是我的本意。如何连接嵌套数据集,或以其他方式展平数据集,以便可以将来自不同文件的样本批处理在一起?

标签: pythontensorflowtensorflow-datasets

解决方案


我在 TF 2.0 中遇到了类似的问题,并且我使用了unbatch函数,如下所示:

dataset = dataset.flat_map(lambda f: parse_function(f)).apply(tf.data.experimental.unbatch())

我相信可以使用TF 1.8 tf.contrib.data.unbatch 。


推荐阅读