How to manually stop model training from the on_batch_end callback

Problem description

How can I stop training early from an on_batch_end callback? I tried setting the model.stop_training attribute to True, but it does not seem to work.

Here is the code I am using:

from tensorflow.keras.callbacks import LambdaCallback

callback = LambdaCallback(on_batch_end=lambda batch, logs: self.on_batch_end(batch, logs))
self.model.fit(
    x=trainData,
    steps_per_epoch=stepsPerEpoch,
    epochs=epochs,
    verbose=verbose,
    callbacks=[callback])

Here is my on_batch_end method. Keep in mind that it is entered many times, yet it never ends the training process.

def on_batch_end(self, batch, logs):
    # grab the current learning rate and log it to the list of
    # learning rates that we've tried
    lr = K.get_value(self.model.optimizer.lr)
    self.lrs.append(lr)

    # grab the loss at the end of this batch, increment the total
    # number of batches processed, compute the exponentially weighted
    # average loss with bias correction, and update the losses list
    # with the smoothed value
    l = logs["loss"]
    self.batchNum += 1
    self.avgLoss = (self.beta * self.avgLoss) + ((1 - self.beta) * l)
    smooth = self.avgLoss / (1 - (self.beta ** self.batchNum))
    self.losses.append(smooth)

    # compute the maximum loss stopping factor value
    stopLoss = self.stopFactor * self.bestLoss
    print("\n[INFO]: Comparing Smooth Loss {} and Stop Loss {}".format(smooth, stopLoss))
    # check to see whether the loss has grown too large
    if self.batchNum > 1 and smooth > stopLoss:
      # stop training and return from the method
      print("[INFO]: Loss is too high. Stopping training!")
      self.model.stop_training = True
      return

    # check to see if the best loss should be updated
    if self.batchNum == 1 or smooth < self.bestLoss:
      self.bestLoss = smooth
      
    # increase the learning rate
    lr *= self.lrMult
    K.set_value(self.model.optimizer.lr, lr)

Here is the training output:

[INFO]: Loss is too high. Stopping training!
524/535 [============================>.] - ETA: 0s - loss: 19639.2344 - binary_accuracy: 0.5551
[INFO]: Comparing Smooth Loss 10783.845046550889 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!
525/535 [============================>.] - ETA: 0s - loss: 19726.4941 - binary_accuracy: 0.5555
[INFO]: Comparing Smooth Loss 10962.001075307371 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!

[INFO]: Comparing Smooth Loss 11144.855858488723 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!
527/535 [============================>.] - ETA: 0s - loss: 20104.7402 - binary_accuracy: 0.5560
[INFO]: Comparing Smooth Loss 11329.031436631449 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!

Tags: tensorflow, keras, deep-learning

Solution


I had the same problem; it seems Keras only interrupts training at the end of an epoch. If you set self.model.stop_training = True in a callback after a batch, training continues with the following batches until the epoch ends, and only then stops.
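
Here is a minimal, self-contained sketch of that behaviour (the toy model and random data are my own assumptions, not the asker's setup). On versions where this issue reproduces, the flag is raised on the very first batch, yet the progress bar still runs through the whole epoch before training halts:

import numpy as np
import tensorflow as tf

# hypothetical toy model and data, just to demonstrate the behaviour
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")
x = np.random.rand(320, 4)
y = np.random.rand(320, 1)

def stop_now(batch, logs):
    # ask Keras to stop after the very first batch of the epoch
    model.stop_training = True

callback = tf.keras.callbacks.LambdaCallback(on_batch_end=stop_now)

# on affected versions, the epoch still runs through all 10 batches
# (320 samples / batch size 32) before training actually stops
model.fit(x, y, batch_size=32, epochs=5, callbacks=[callback])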

One workaround I found is to reduce steps_per_epoch when calling tf.keras.Model.fit. With shorter epochs you reach an epoch boundary more often, which gives you finer control over the stopping condition.
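
For instance (a sketch only; the step counts are assumptions, and the variable names mirror the question's code), one "real" epoch of 535 steps can be sliced into many short Keras epochs so that the stop flag is checked every few batches:

# shrink the epochs so Keras reaches an epoch boundary (where
# stop_training is honoured) every few batches; the values here
# are assumptions, not taken from the question
shortEpochSteps = 16
# keep the total number of batches roughly the same as before
scaledEpochs = (epochs * stepsPerEpoch) // shortEpochSteps

self.model.fit(
    x=trainData,
    steps_per_epoch=shortEpochSteps,
    epochs=scaledEpochs,
    verbose=verbose,
    callbacks=[callback])

Note that this assumes trainData is a generator or a repeating tf.data.Dataset, so that consecutive short epochs simply keep consuming batches where the previous one left off.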

