首页 > 解决方案 > 使用 Keras model.fit,如何设置它以保存每 x 个步骤?

问题描述

我想在运行 model.fit 时每隔 x 步保存我的模型。

我正在查看文档 https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit

而且似乎没有选择。但是在训练期间保存检查点是一个很常见的用例,很难想象没有办法做到这一点。所以我想知道我是否忽略了一些东西。

标签: tensorflowkerastensorflow2.0tf.keras

解决方案


这可以使用ModelCheckpoint 回调来完成:

EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_acc',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

monitor您可以使用、mode和参数修改回调的行为,这些save_best_only参数控制要跟踪的指标以及检查点是否被覆盖以仅保留最佳模型。


推荐阅读