tensorflow - 在 Keras 中使用 TF Dataset API 的既定方法是使用 `make_one_shot_iterator()` 提供 `model.fit`,但这个迭代器只适用于一个 Epoch
问题描述
编辑:
为了澄清为什么这个问题与建议的重复项不同,这个 SO 问题跟进了那些建议的重复项,Keras 究竟在用这些 SO 问题中描述的技术做了什么。建议的重复项使用数据集 API 指定make_one_shot_iterator()
,model.fit
我的后续行动是make_one_shot_iterator()
只能通过数据集一次,但是在给出的解决方案中,指定了几个时期。
这是对这些 SO 问题的跟进
如何正确结合 TensorFlow 的 Dataset API 和 Keras?
使用 tf.data.Dataset 作为 Keras 模型的训练输入不起作用
其中“从 Tensorflow 1.9 开始,可以将 tf.data.Dataset 对象直接传递给 keras.Model.fit(),它的行为类似于 fit_generator”。每个示例都有一个 TF 数据集 one shot iterator 输入到 Kera 的 model.fit 中。
下面给出一个例子
# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train,is_training=True)
model = # your keras model here
model.fit(
training_set.make_one_shot_iterator(),
steps_per_epoch=len(x_train) // 128,
epochs=5,
verbose = 1)
但是,根据 Tensorflow 数据集 API 指南(此处为https://www.tensorflow.org/guide/datasets):
one-shot 迭代器是最简单的迭代器形式,仅支持对数据集进行一次迭代
所以它只适用于 1 个 epoch。但是,SO 问题中的代码指定了几个 epoch,上面的代码示例指定了 5 个 epoch。
这个矛盾有什么解释吗?Keras 是否知道当单次迭代器遍历数据集时,它可以重新初始化和打乱数据?
解决方案
您可以简单地将数据集对象传递给model.fit
,Keras 将处理迭代。考虑一个预制数据集:
train, test = tf.keras.datasets.cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))
这将从 cifar10 数据集的训练数据创建数据集对象。在这种情况下,不需要解析函数。如果您从包含 numpy 数组列表图像的路径创建数据集,您将需要一个。
dataset = tf.data.Dataset.from_tensor_slices((image_path, labels_path))
如果您需要一个函数来从文件名加载实际数据。Numpy 数组可以用同样的方式处理,只是不需要tf.read_file
def parse_func(filename):
f = tf.read_file(filename)
image = tf.image.decode_image(f)
label = #get label from filename
return image, label
然后,您可以对任何解析函数进行混洗、批处理和映射到该数据集。您可以控制使用随机缓冲区预加载的示例数量。重复控制 epoch 计数,最好保留 None,因此它将无限重复。您可以使用普通批处理功能或结合使用
dataset = dataset.shuffle().repeat()
dataset.apply(tf.data.experimental.map_and_batch(map_func=parse_func, batch_size,num_parallel_batches))
然后可以将数据集对象传递给model.fit
model.fit(dataset, epochs, steps_per_epoch)。请注意,steps_per_epoch
在这种情况下这是一个必要的参数,它将定义何时开始新的纪元。所以你必须提前知道纪元大小。
推荐阅读
- jquery - 类型错误 undefined thumnail_path 未定义
- python - 如何在 Windows 系统上执行 Kaggle Api 命令?
- java - 如何将arraylist作为整数返回
- machine-learning - 尽管种子值不同,但 Weka 的预测结果一致
- javascript - 如何解决运行时错误:this.initializationApp 不是函数(对不起英语,我是巴西人)
- c# - SQL 更新语句显示 OldShippedDate == NewShippedDate(当我更改新的发货日期时)以尝试更新。抛出sql并发错误
- java - 抛出 catch 块是一种好的编码习惯吗?
- tomcat - 8080 或 8081 是 tomcat 或者我们可以将它用于任何应用程序
- swift - UIStackView 在 IB 中创建一个有间隙的表
- node.js - 使用 nodejs 更改 MongoDB gridfs 中的块大小