首页 > 解决方案 > 如果 sparse_validation_accuracy 高于基线,如何停止训练?

问题描述

我正在拟合一个模型并尝试在95% 以上huggingface时设置一个提前停止。sparse_validation_accuracy

我正在使用以下调用:

early_stopper = tf.keras.callbacks.EarlyStopping(monitor='accuracy', 
                                                 baseline = 0.90,
                                                 patience = 0,
                                                 restore_best_weights=True)

# train the model 
model.fit(train_dataset.shuffle(len(x_train)).batch(BATCH_SIZE),
          epochs=N_EPOCHS,
          batch_size=BATCH_SIZE,
          callbacks = [early_stopper])

不幸的是,模型一直在训练,如下所示。我错过了什么吗?

0.9688WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.6875s vs `on_train_batch_end` time: 1.1250s). Check your callbacks.
  73/7495 [..............................] - ETA: 3:43:56 - loss: 0.1147 - sparse_categorical_accuracy: 0.9546

标签: pythontensorflow

解决方案


实际上,内置EarlyStopping回调仅在 epoch 结束时起作用。因此,它不会在一个时期的中间停止你的训练。如果您想要一个回调,这将在 epoch 尚未结束时停止训练,请尝试将您的自定义回调创建为tf.keras.callbacks.Callback. 您将需要覆盖该on_train_batch_end方法。

您生成的回调可能是这样的:

class CustomEarlyStopping(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, logs=None):
        if logs['sparse_categorical_accuracy'] > 0.95:
            self.model.stop_training = True

我已经有一段时间没有练习 TF 了,所以这段代码可能无法开箱即用,但它是可以开始的。更多信息可以在关于编写自定义回调回调类参考的官方文档中找到。


推荐阅读