python - 如何保存一些最好的 Keras 模型?
问题描述
我需要保存一些最好的模型。我从 Keras 定义了 ModelCheckpoint,它只保存了一个最佳模型。
checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
monitor='val_loss',
verbose=1,
save_best_only=True,
save_weights_only=True,
mode='min'
)
有任何想法吗?
解决方案
save_best_only=False
这样,您将在每个 epoch 之后保存模型,而不仅仅是当模型被认为是“最好的”时。
请记住为每个时期定义更改的路径,否则模型将被覆盖。
评论后:
如果你的回调只保存一个模型,可能是因为你已经定义了静态路径,你需要使用 keras 格式,以便保存模型的文件上的名称在每个 epoch 都会改变,否则会被覆盖。
您需要以如下方式定义路径:
filepath= model_folder_path + "/epoch:{epoch:02d}-val_loss:{val_loss:.2f}.hdf5"
如果每次 keras 保存新的最佳模型时将路径定义为静态,则之前保存的模型将被覆盖,您将只保留其中一个(最后一个保存的模型)。
请查看文档以获取有关如何编写路径的更多信息:
https://keras.io/api/callbacks/model_checkpoint/
第二次评论后:
好的,现在你想要的是只保存一些最好的模型,而不是每个人,我使用这样的回调实现了这个结果:
class too_many_models(Callback):
def __init__(self, mypath, max_num_of_models):
self.path = mypath
self.mnom = max_num_of_models
def on_epoch_end(self, epoch, logs=None):
onlyfiles = [f for f in listdir(self.path) if isfile(join(self.path, f))]
onlyfiles.sort()
if len(onlyfiles) > self.mnom:
os.remove(join(self.path, onlyfiles[0]))
model_folder = "supertest"
!mkdir "/content/gdrive/My Drive/"$model_folder
folderpath = "/content/gdrive/My Drive/" + model_folder
filepath = folderpath + "/epoch:{epoch:02d}-val_loss:{val_loss:.2f}.hdf5"
my_callbacks = [
tf.keras.callbacks.ModelCheckpoint(filepath=filepath, save_weights_only=False, monitor='val_loss', mode='min', save_best_only=True),
too_many_models(mypath=folderpath, max_num_of_models=5)
]
你想像这样使用这个回调:
model.fit(train, steps_per_epoch=len(train), validation_data=valid, callbacks=my_callbacks)
需要注意的几点:
在保存模型的回调too_many_models
之后调用回调。
该文件夹需要在开始培训之前为空,每次培训都需要一个不同的文件夹(但我通常这样做,所以我不认为这是一个问题)。
我对文件进行排序,以这种方式我确定我要删除的模型(在索引 0 处)总是保存的旧模型(在纪元中具有较低的数字),这仅是因为我以这种方式定义了模型的名称, 纪元数按我想要的方式对文件进行排序。(我希望这一点很清楚)
您可以使用 更改保存的模型数量max_num_of_models
,使用 5 将保存最后最好的 5 个模型。
folderpath
需要是您保存模型的路径。(注意我如何定义不同的路径)
我只是测试了一切以确保它按预期工作。
我知道这是一个非常粗略的解决方案,但它完成了工作:D
推荐阅读
- python - 将包含字符串数组的 .mat 文件加载到 Python 3.6
- ios - Xcode 11 每次运行前都会将应用重新安装到设备上
- mysql - 如何使列在其他列中具有唯一值
- r - 根据R中另一列中的值范围按列值选择行
- r - 从长格式转换为具有多个唯一变量的宽格式到 R 中的其他唯一变量
- oracle - 从三个不同的列中选择最小值 oracle
- prometheus - 在 PromQL 中将峰值检测为滚动百分比
- mongodb - ObjectID `_id` 是在哪里生成的?
- sql-server - 为什么我的 SQL Server 调用有时会停滞 14.3 秒?
- javascript - 如何在three.js中绝对在相机上翻译Z?