首页 > 解决方案 > 如何在 Tensorflow 中有效地加入来自 TFRecords 的数据

问题描述

TFRecord在s上训练 TensorFlow 模型时,我需要有效地加入少量数据。如何使用已解析的信息进行此查找TFRecord

更多细节:

我正在使用TFRecords. 每个都TFRecord包含原始图像以及目标标签,以及有关图像的一些元数据。训练的一部分是我需要使用特定于一组图像的mean和标准化图像。std过去为了做到这一点,我将meanand硬编码stdTFRecord. 然后在 myparse_example中使用它来映射Datasetmy 中的input_fn,如下所示:

def parse_example(..):
    # ...
    parsed = tf.parse_single_example(value, keys_to_features)
    image_raw = tf.decode_raw(parsed['image/raw'], tf.uint16)
    image = tf.reshape(image_raw, image_shape)
    image.set_shape(image_shape)

    # pull hardcoded pixels mean and std from the parsed TFExample
    mean = parsed['mean']
    std = parsed['std']

    image = (tf.cast(image, tf.float32) - mean) / std

    # ...

    return image, label

虽然上述方法有效并且可以缩短训练时间,但它的局限性在于我经常想改变我使用的mean东西。我宁愿在训练时查找适当的汇总统计数据,std而不是将meanandstd写入s。TFRecord这意味着当我训练时,我有一个小的 Python 字典,我可以使用从TFRecord. 我遇到的问题是我似乎无法在我的张量流图中使用这个 python 字典。如果我尝试直接进行查找,它不起作用,因为我有张量对象而不是实际的原语。这是有道理的input_fn正在为 TensorFlow 构建计算图的符号操作(对吗?)。我该如何解决这个问题?

我尝试过的一件事是从字典中创建一个查找表,如下所示:

def create_channel_hashtable(keys, values, default_val=-1):
    initializer = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)
    return tf.contrib.lookup.HashTable(initializer, default_val)

可以在parse_example函数中创建和使用哈希表来进行查找。这一切都“有效”,但它极大地减慢了训练速度。值得注意的是,这种培训是在 TPU 上进行的。使用来自TFRecords 的值的原始方法,训练速度非常快,并且不受 IO 的限制,但是当使用哈希查找时,这种情况会发生变化。处理这些情况的建议方法是什么?虽然重新打包TFRecords 是可行的,但当要查找的数据很小并且可以提高效率时,这似乎很愚蠢。

标签: pythontensorflowtfrecordtpu

解决方案


这个问题涵盖了这个主题:

如何将多个 tfrecords 文件合并到一个文件中?

似乎您会将 TFRecords 保存到文件中,然后使用 TFRecordDataset 将它们全部拉出到一个数据集中。我链接的上述问题的答案中给出的代码是:

dataset = tf.data.TFRecordDataset(filenames_to_read,
    compression_type=None,    # or 'GZIP', 'ZLIB' if compress you data.
    buffer_size=10240,        # any buffer size you want or 0 means no buffering
    num_parallel_reads=os.cpu_count()  # or 0 means sequentially reading
) 

推荐阅读