首页 > 解决方案 > KeyError:'未能格式化此回调文件路径:原因:\'lr\''

问题描述

我最近从 Tensorflow 2.2.0 切换到 2.4.1,现在我的ModelCheckpoint回调路径有问题。如果我使用具有 tf 2.2 的环境,则此代码可以正常工作,但在使用 tf 2.4.1 时会出错。

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[checkpoint])

错误:

KeyError:'无法格式化此回调文件路径:“path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}”。原因:'lr''

标签: tensorflowkerascallback

解决方案


ModelCheckpoint,参数的格式化名称filepath,只能包含:epoch+logs纪元结束后的键。

您可以在日志中看到可用的密钥,如下所示:

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))

model.fit(..., callbacks=[CustomCallback()])

如果您运行上面的代码,您将看到如下内容:

Log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']

这显示了您可以使用的可用键(加号epoch)和lr不可用的键(您使用了 3 个键:epochlrin val_lossname filepath)。


解决方案:

您可以自己将学习率添加到日志中:

import tensorflow.keras.backend as K
class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs.update({'lr': K.eval(self.model.optimizer.lr)})
        keys = list(logs.keys())
        print("Log keys: {}".format(keys)) #you will see now `lr` available

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[checkpoint, CustomCallback()])

推荐阅读