首页 > 解决方案 > 保存没有自定义对象的完整 tf.keras 模型?

问题描述

假设我们有一些模型,包括在训练期间很重要的自定义损失和指标。是否可以保存完整的模型,所以 weights + graphdef / pb-file,没有自定义对象?

在推理过程中,不需要自定义损失和指标,因此......

tf.keras.models.load_model("some_model", custom_objects={...})

...只会使推理代码更加复杂,因为需要包含自定义目标代码以进行推理(尽管未使用它)。

但是,tf.keras.callbacks.ModelCheckpoint(即使使用include_optimizer=False)以及调用model.save()总是保存模型定义,包括自定义对象。

因此,只需加载模型...

tf.keras.models.load_model("some_model")

...总是会失败并抱怨缺少自定义对象。

是否可以在没有自定义损失/指标的情况下以某种方式保存整个模型?要获得易于加载的网络“推理”版本?

或者是将所有内容冻结为 TFLite 模型的唯一解决方案?

当然,也可以简单地使用model.save_weights(),但随后需要包含实际代码以供稍后推断,这是不希望的。

标签: pythontensorflowkerassaverestore

解决方案


如果目的是防止加载损失和指标,您可以使用以下compile参数load_model

model = tf.keras.models.load_model("some_model", compile=False)

这应该跳过损失和指标/优化器的要求,因为模型没有被编译。当然你现在不能训练模型,但它应该可以很好地使用推理model.predict()


推荐阅读