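tensorflow - Callbacks in TensorFlow 2.3
Problem description
I am writing my own callback to stop training based on a custom condition. EarlyStopping stops training once its condition is met by setting: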
self.model.stop_training = True
For example, from https://www.tensorflow.org/guide/keras/custom_callback
import numpy as np
from tensorflow import keras

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its minimum, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after the minimum has been hit.
            After this many epochs without improvement, training stops.
    """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights stores the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs waited since the loss was last at its minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best loss as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
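The problem is that it does not work with TensorFlow 2.2 and 2.3. Any ideas for a workaround? How else can I stop training a model in TF 2.3?
Solution
I copied your code and added some print statements to see what is happening. I also changed the monitored loss from the training loss to the validation loss, because the training loss tends to keep decreasing over many epochs while the validation loss tends to level off sooner. It is better to monitor the validation loss for early stopping and saving weights than to use the training loss. Your code runs fine and stops training if the loss has not decreased after the patience number of epochs. Make sure you have the code below: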
patience = 3  # set patience value
callbacks = [EarlyStoppingAtMinLoss(patience)]
# in model.fit, pass callbacks=callbacks (and validation data, since the callback below reads val_loss)
Here is your code modified with print statements so you can see what is happening:
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights stores the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs waited since the loss was last at its minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best loss as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        # val_loss is only present in logs when validation data is passed to model.fit.
        current = logs.get("val_loss")
        print('epoch = ', epoch + 1, ' loss = ', current, ' best_loss = ', self.best, ' wait = ', self.wait)
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            print(' loss improved, setting wait to zero and saving weights')
            # Record the best weights if the current result is better (lower).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            print(' for epoch ', epoch + 1, ' loss did not improve, setting wait to ', self.wait)
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
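To check the patience rule without training a model, the same bookkeeping can be replayed over a plain list of loss values. This is only a sketch: the function name early_stopping_epoch and the sample losses are made up for illustration and are not part of Keras.

```python
import math


def early_stopping_epoch(val_losses, patience):
    """Replay the callback's wait/best bookkeeping over a list of losses.

    Returns (stopped_epoch, best_epoch) as 0-based indices. stopped_epoch is
    None if the stop condition is never reached.
    """
    best = math.inf
    best_epoch = None
    wait = 0
    for epoch, current in enumerate(val_losses):
        if current < best:
            # Improvement: remember it and reset the counter.
            best = current
            best_epoch = epoch
            wait = 0
        else:
            # No improvement: count the epoch and stop once patience is used up.
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Hypothetical validation losses, chosen only to exercise the rule.
losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.5]
print(early_stopping_epoch(losses, patience=3))  # (5, 2)
```

With patience=3 the loss last improves at index 2, the next three epochs do not beat it, and the stop triggers at index 5, so the later value 0.5 is never reached.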
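I copied your new code and ran it. Apparently TensorFlow does not evaluate model.stop_training during batch processing. So even if model.stop_training is set to True in on_train_batch_end, it keeps processing batches until all batches of that epoch are done. Then, at the end of the epoch, TensorFlow evaluates model.stop_training and training does stop.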
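The behavior described above can be pictured with a toy loop. This is not TensorFlow's actual implementation: ToyModel and toy_fit are hypothetical stand-ins that assume the flag is read only once per epoch, as observed in the answer.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras model; only the stop flag matters."""

    def __init__(self):
        self.stop_training = False


def toy_fit(model, epochs, steps_per_epoch, stop_at_batch):
    """Toy training loop that checks stop_training only at the end of an epoch.

    Returns (epochs_run, batches_run).
    """
    batches_run = 0
    epochs_run = 0
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            batches_run += 1
            # Plays the role of on_train_batch_end setting the flag.
            if batches_run == stop_at_batch:
                model.stop_training = True
            # The flag is deliberately not checked inside the batch loop.
        epochs_run += 1
        # The flag is only read here, once the epoch is complete.
        if model.stop_training:
            break
    return epochs_run, batches_run


print(toy_fit(ToyModel(), epochs=5, steps_per_epoch=10, stop_at_batch=3))  # (1, 10)
```

Although the flag is set at the third batch, all 10 batches of the first epoch still run before the loop exits.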