首页 > 解决方案 > 在 Keras 中使用 TF Dataset API 的既定方法是使用 `make_one_shot_iterator()` 提供 `model.fit`,但这个迭代器只适用于一个 Epoch

问题描述

编辑:

为了澄清为什么这个问题与建议的重复项不同,这个 SO 问题跟进了那些建议的重复项,Keras 究竟在用这些 SO 问题中描述的技术做了什么。建议的重复项使用数据集 API 指定make_one_shot_iterator()model.fit我的后续行动是make_one_shot_iterator()只能通过数据集一次,但是在给出的解决方案中,指定了几个时期。


这是对这些 SO 问题的跟进

如何正确结合 TensorFlow 的 Dataset API 和 Keras?

带有 tf 数据集输入的 TensorFlow keras

使用 tf.data.Dataset 作为 Keras 模型的训练输入不起作用

其中“从 Tensorflow 1.9 开始,可以将 tf.data.Dataset 对象直接传递给 keras.Model.fit(),它的行为类似于 fit_generator”。每个示例都有一个 TF 数据集 one shot iterator 输入到 Kera 的 model.fit 中。

下面给出一个例子

# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train,is_training=True)

model = # your keras model here              
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // 128,
    epochs=5,
    verbose = 1)

但是,根据 Tensorflow 数据集 API 指南(此处为https://www.tensorflow.org/guide/datasets):

one-shot 迭代器是最简单的迭代器形式,仅支持对数据集进行一次迭代

所以它只适用于 1 个 epoch。但是,SO 问题中的代码指定了几个 epoch,上面的代码示例指定了 5 个 epoch。

这个矛盾有什么解释吗?Keras 是否知道当单次迭代器遍历数据集时,它可以重新初始化和打乱数据?

标签: tensorflowkerastensorflow-datasetstf.keras

解决方案


您可以简单地将数据集对象传递给model.fit,Keras 将处理迭代。考虑一个预制数据集:

train, test = tf.keras.datasets.cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))

这将从 cifar10 数据集的训练数据创建数据集对象。在这种情况下,不需要解析函数。如果您从包含 numpy 数组列表图像的路径创建数据集,您将需要一个。

dataset = tf.data.Dataset.from_tensor_slices((image_path, labels_path)) 

如果您需要一个函数来从文件名加载实际数据。Numpy 数组可以用同样的方式处理,只是不需要tf.read_file

def parse_func(filename):
    f = tf.read_file(filename)
    image = tf.image.decode_image(f)
    label = #get label from filename
    return image, label

然后,您可以对任何解析函数进行混洗、批处理和映射到该数据集。您可以控制使用随机缓冲区预加载的示例数量。重复控制 epoch 计数,最好保留 None,因此它将无限重复。您可以使用普通批处理功能或结合使用

dataset = dataset.shuffle().repeat()
dataset.apply(tf.data.experimental.map_and_batch(map_func=parse_func, batch_size,num_parallel_batches))

然后可以将数据集对象传递给model.fit model.fit(dataset, epochs, steps_per_epoch)。请注意,steps_per_epoch在这种情况下这是一个必要的参数,它将定义何时开始新的纪元。所以你必须提前知道纪元大小。


推荐阅读