首页 > 解决方案 > 优化和批处理字节文件数据到 keras 模型中

问题描述

我正在尝试为我的 model.fit 提供分块数据,因为我的整个数据集不适合我的记忆。TF版本是2.4。

我已经优化了存储在文件中的数据集,将其保存为字节,每个列是 1 个字节(0-255),但现在我需要以块的形式读取它并开始我的网络的学习过程。

在此之前,我使用 .csv,使用 numpy reshape 加载数据集,然后将整个数据集提供给 keras 模型。

我应该使用从以前的迭代.fit中加载的每个块吗?load_weight

标签: pythontensorflowkerasdataset

解决方案


file成为存储数据的文件。要创建数据集,首先需要创建一个generator用于读取数据的数据集。从原理上讲,它应该具有以下结构:

    def generator(file):
        with open(file,'r') as f:
            for linne in f:
                ### code for generating single sample and, optionally, label ###
                yield (sample,label)

无论如何,您都不要发出文件结束的信号。当generator被调用时,它会以pairs 的形式产生所有数据(sample,label)。要检查是否generator正确生成数据,可以执行类似的操作

    for (sample,label) in generator(file):
        ### examine your samples and labels

如果为了生成单个样本,您需要处理多于一行file,则应相应地修改代码(或文件)。

一旦你有一个工作生成器,你可以创建一个dataset

    dataset = tf.data.Dataset.from_generator( generator(file), 
                output_signature=( 
                    ( tf.TensorSpec(shape=(shape-of-your-data)),
                      tf.TensorSpec(shape=(shape-of-your-labels)) ) )

您可以从不同的文件中读取样本和标签。如果dataset_Adataset_B分别包含样本和标签,则需要使用zip数据集的方法以训练对的形式获取数据集生成数据。

此外,您需要从单个样品中批量生产:

    dataset = dataset.batch(32)

为了改进准备数据的管道,数据集的良好做法是prefetch

    dataset = dataset.prefetch(number-of-batches-to-prefetch)

当当前数据通过网络传递时,这会“预置”数据,有关详细信息,请参阅本教程(我不建议将tf.data.AUTOTUNE其用作参数prefetch;通常手动设置参数会产生更好的性能)。

结果dataset可用于拟合模型model.fit。请注意,对于数据集,shuffle参数 ofmodel.fit被忽略。要打乱您的数据,您需要shuffle使用datasets.

h5就个人而言,我发现它有用地从文件中生成数据。在这种情况下,您只需要像在 numpy 数组中那样迭代样本,逐行读取文件。这更灵活(例如,您可以直接在生成器中打乱数据)并且在读取数据方面更方便。


推荐阅读