首页 > 解决方案 > 来自 avro 文件的 tf.data.dataset

问题描述

我正在尝试使用tf.data.Datasetand并行化我的输入管道TFRecordDataset

files = tf.data.Dataset.list_files("./data/*.avro")
dataset = tf.data.TFRecordDataset(files, num_parallel_reads=16)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    preprocess_fn, 512, num_parallel_batches=16) )

preprocess_fn如果输入是 AVRO 文件(类似于 JSON),我不确定如何编写。


目前,我正在使用并提供由或类似的 avro 阅读器tf.data.Dataset.from_generator解析的 avro 记录。pyavroc但我不确定如何并行化它,因为from_generator方法没有可用的num_parallel_reads选项。

def gen():
    for file in all_avro_files:
        x, y = read_local_avro_data(file)
        for i, sample in enumerate( x ):
            yield sample, y[i]

dataset = tf.data.Dataset.from_generator( gen, 
            (tf.float32, tf.float64),
            ( tf.TensorShape([13000]), tf.TensorShape([]) 
        ) 
    )

逐个文件读取显然是一个瓶颈,我看到所有内核在用完前一批数据后都在等待数据。

如何优化这两种方法?

标签: pythontensorflowkerastensorflow-datasetstfrecord

解决方案


推荐阅读