tensorflow - KeyError:'未能格式化此回调文件路径:原因:\'lr\''
问题描述
我最近从 Tensorflow 2.2.0 切换到 2.4.1,现在我的ModelCheckpoint
回调路径有问题。如果我使用具有 tf 2.2 的环境,则此代码可以正常工作,但在使用 tf 2.4.1 时会出错。
checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')
history = model.fit(training_data, training_data,
epochs=10,
batch_size=32,
shuffle=True,
validation_data=(validation_data, validation_data),
verbose=verbose, callbacks=[checkpoint])
错误:
KeyError:'无法格式化此回调文件路径:“path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}”。原因:'lr''
解决方案
中ModelCheckpoint
,参数的格式化名称filepath
,只能包含:epoch
+logs
纪元结束后的键。
您可以在日志中看到可用的密钥,如下所示:
class CustomCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
keys = list(logs.keys())
print("Log keys: {}".format(keys))
model.fit(..., callbacks=[CustomCallback()])
如果您运行上面的代码,您将看到如下内容:
Log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
这显示了您可以使用的可用键(加号epoch
)和lr
不可用的键(您使用了 3 个键:epoch
和lr
in val_loss
name filepath
)。
解决方案:
您可以自己将学习率添加到日志中:
import tensorflow.keras.backend as K
class CustomCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
logs.update({'lr': K.eval(self.model.optimizer.lr)})
keys = list(logs.keys())
print("Log keys: {}".format(keys)) #you will see now `lr` available
checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')
history = model.fit(training_data, training_data,
epochs=10,
batch_size=32,
shuffle=True,
validation_data=(validation_data, validation_data),
verbose=verbose, callbacks=[checkpoint, CustomCallback()])
推荐阅读
- ios - 为什么适用于 iOS 的 ionic cordova 构建失败并出现错误
- velocity - 速度模板引擎中的添加
- python - pyodbc DSN - 连接到连接字符串中没有 UID 和 PWD 的 SQL Server
- javascript - 使用带有js的按钮获取while循环的值
- arraylist - 如何在 kotlin 中获取 ArrayList 的最后一个值?
- flutter - 颤振检测屏幕是否关闭
- sql - 如何进行 sql 数据透视
- java - 在 Spring 中将 Jade4J 配置为默认模板引擎
- javascript - 查找表中最大的数字
- shell - 在 groovy 中执行 shell 命令会引发“意外字符”错误