首页 > 解决方案 > 张量流 2.3 中的回调

问题描述

我正在编写自己的回调以根据某些自定义条件停止训练。一旦满足条件,EarlyStopping 就会停止训练:

self.model.stop_training = True

例如来自https://www.tensorflow.org/guide/keras/custom_callback

class EarlyStoppingAtMinLoss(keras.callbacks.Callback): """当损失达到最小值时停止训练,即损失停止减少。

参数: 耐心:达到 min 后等待的 epoch 数。在这个数量没有改善之后,训练就停止了。"""

def __init__(self, patience=0):
    super(EarlyStoppingAtMinLoss, self).__init__()
    self.patience = patience
    # best_weights to store the weights at which the minimum loss occurs.
    self.best_weights = None

def on_train_begin(self, logs=None):
    # The number of epoch it has waited when loss is no longer minimum.
    self.wait = 0
    # The epoch the training stops at.
    self.stopped_epoch = 0
    # Initialize the best as infinity.
    self.best = np.Inf

def on_epoch_end(self, epoch, logs=None):
    current = logs.get("loss")
    if np.less(current, self.best):
        self.best = current
        self.wait = 0
        # Record the best weights if current results is better (less).
        self.best_weights = self.model.get_weights()
    else:
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            print("Restoring model weights from the end of the best epoch.")
            self.model.set_weights(self.best_weights)

def on_train_end(self, logs=None):
    if self.stopped_epoch > 0:
        print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

问题是,它不适用于 tensorflow 2.2 和 2.3。任何解决方法的想法?还有什么办法可以停止在 tf 2.3 中训练模型?

标签: tensorflowkerascallbackearly-stopping

解决方案


我复制了您的代码并添加了一些打印语句以查看发生了什么。我还将被监控的损失从训练损失更改为验证损失,因为训练损失往往会在许多时期内不断减少,而验证损失往往会更快地趋于平稳。最好监控验证损失以提前停止和节省权重,然后使用训练损失。您的代码运行良好,并且如果在耐心的 epoch 数后损失没有减少,则停止训练。确保你有下面的代码

patience=3 # set patience value
callbacks=[EarlyStoppingAtMinLoss(patience)]
# in model.fit include callbacks=callbacks

这是您使用打印语句修改的代码,因此您可以看到发生了什么

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("val_loss")
        print('epoch = ', epoch +1, '   loss= ', current, '   best_loss = ', self.best, '   wait = ', self.wait)
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            print ( ' loss improved setting wait to zero and saving weights')
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            print ( ' for epoch ', epoch +1, '  loss did not improve setting wait to ', self.wait)
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

我复制了你的新代码并运行了它。显然 tensorflow 不会在批处理期间评估 model.stop_training。因此,即使 model.stop_training 在 on_train_batch_end 中设置为 True,它也会继续处理批次,直到该时期的所有批次都完成。然后在 epoch 结束时 tensorflow 评估 model.stop_training 并且训练确实停止了。


推荐阅读