首页 > 解决方案 > 如何使用 TensorFlow 的 Keras API 为每个 epoch 的保存模型生成唯一名称

问题描述

我正在训练模型fit_generator()并希望为每个时代保存的 wights 生成唯一的名称

已经尝试过:查看后面的代码

代码:

model_path = '.\checkpoints\cp{}.ckpt'.format(time())
cp_callback = tf.keras.callbacks.ModelCheckpoint(model_path, 
                                                 verbose=1,
                                                  period=2)
model.fit_generator(..........,callbacks=[cp_callback])

预期:生成唯一的检查点名称
,例如 epoch_4.ckpt 或 epoch_5.ckpt
实际:每次保存时,覆盖现有检查点

标签: tensorflowtf.keras

解决方案


您可以尝试将 epoch 变量添加到文件路径变量

filepath = ".\checkpoints\cp-{epoch:02d}.hdf5"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath, verbose=1, period=2)

这里的 Keras 文档已经提到了这一点。


推荐阅读