首页 > 解决方案 > 使用 TensorFlow Data API 和 Keras 模型预测批次

问题描述

假设我有一个数据集和一个 Keras 模型。batch()数据集已使用tf Dataset API分为批次。现在我正在寻找一种有效且干净的方法来对所有测试样本进行批量预测。

我已经尝试了以下代码并且它有效。

batch_size = 32
dataset = dataset.batch(batch_size)
predictions = keras_model.predict(dataset, steps=math.ceil(num_testing_samples / batch_size))

我想知道有没有更有效和优雅的方法来实现这一点?

标签: tensorflowkerastensorflow-datasets

解决方案


TF >= 1.14.0

你可以只设置steps=None. 来自官方文档tf.keras.Model.predict()

如果 x 是一个 tf.data 数据集并且 steps 是 None,predict 将运行直到输入数据集用完。

只需确保您的dataset对象未处于重复模式,您就可以开始了 :)。

TF 1.12.0 和 1.13.0

tf.data.Dataset这些版本对with的支持tf.keras很差。对象在此处tf.data.Dataset转换为迭代器,如果您未设置 参数,则会在此处触发错误。这是在 1.14.0 中修补的。steps


推荐阅读