tensorflow - How to manually stop model training with an on_batch_end callback
Problem description
How can I stop model training early from an on_batch_end callback? I tried setting the model.stop_training attribute to True, but it doesn't seem to work.
Here is the code I'm using:
callback = LambdaCallback(on_batch_end=lambda batch, logs: self.on_batch_end(batch, logs))
self.model.fit(
    x=trainData,
    steps_per_epoch=stepsPerEpoch,
    epochs=epochs,
    verbose=verbose,
    callbacks=[callback])
Here is my on_batch_end method. Note that it is entered multiple times, yet it never ends the training process.
def on_batch_end(self, batch, logs):
    # grab the current learning rate and log it to the list of
    # learning rates that we've tried
    lr = K.get_value(self.model.optimizer.lr)
    self.lrs.append(lr)
    # grab the loss at the end of this batch, increment the total
    # number of batches processed, compute the average loss,
    # smooth it, and update the losses list with the smoothed value
    l = logs["loss"]
    self.batchNum += 1
    self.avgLoss = (self.beta * self.avgLoss) + ((1 - self.beta) * l)
    smooth = self.avgLoss / (1 - (self.beta ** self.batchNum))
    self.losses.append(smooth)
    # compute the maximum loss stopping factor value
    stopLoss = self.stopFactor * self.bestLoss
    print("\n[INFO]: Comparing Smooth Loss {} and Stop Loss {}".format(smooth, stopLoss))
    # check to see whether the loss has grown too large
    if self.batchNum > 1 and smooth > stopLoss:
        # stop training and return from the method
        print("[INFO]: Loss is too high. Stopping training!")
        self.model.stop_training = True
        return
    # check to see if the best loss should be updated
    if self.batchNum == 1 or smooth < self.bestLoss:
        self.bestLoss = smooth
    # increase the learning rate
    lr *= self.lrMult
    K.set_value(self.model.optimizer.lr, lr)
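The smoothing in the method above is a bias-corrected exponential moving average (the same correction used in Adam). It can be checked in isolation with plain Python; smoothed_losses and the beta value here are illustrative, not part of the original code:

```python
def smoothed_losses(losses, beta=0.98):
    """Exponentially smooth a sequence of batch losses with bias correction."""
    avg = 0.0
    out = []
    for t, l in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * l   # running EMA, starts at zero
        out.append(avg / (1 - beta ** t))   # divide out the zero-start bias
    return out

# For a constant loss the corrected average equals the loss itself,
# which confirms the bias correction cancels the zero initialization.
print(smoothed_losses([5.0, 5.0, 5.0]))
```

Without the division by (1 - beta ** t), the early values would be pulled toward zero and the stop condition would misfire on the first batches.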
Here is the training output:
[INFO]: Loss is too high. Stopping training!
524/535 [============================>.] - ETA: 0s - loss: 19639.2344 - binary_accuracy: 0.5551
[INFO]: Comparing Smooth Loss 10783.845046550889 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!
525/535 [============================>.] - ETA: 0s - loss: 19726.4941 - binary_accuracy: 0.5555
[INFO]: Comparing Smooth Loss 10962.001075307371 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!
[INFO]: Comparing Smooth Loss 11144.855858488723 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!
527/535 [============================>.] - ETA: 0s - loss: 20104.7402 - binary_accuracy: 0.5560
[INFO]: Comparing Smooth Loss 11329.031436631449 and Stop Loss 2.7601591997381787
[INFO]: Loss is too high. Stopping training!
Solution
I had the same problem: it seems Keras only interrupts training at the end of an epoch. If you set self.model.stop_training = True in a callback after a batch, training continues through the remaining batches of that epoch and only stops once the epoch ends.
One workaround I found is to reduce steps_per_epoch when calling tf.keras.Model.fit. With shorter epochs you get finer-grained control over the stopping condition.
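The effect of shorter epochs can be seen by simulating the control flow: the flag is set mid-epoch, but the outer loop only consults it between epochs. This is a plain-Python sketch of that scheduling logic (batches_until_stop is a made-up helper, not Keras code):

```python
def batches_until_stop(stop_at_batch, steps_per_epoch, max_epochs=100):
    """Simulate a fit loop that only honors stop_training between epochs."""
    stop_training = False
    seen = 0
    for _ in range(max_epochs):
        for _ in range(steps_per_epoch):
            seen += 1
            if seen >= stop_at_batch:  # the callback sets the flag mid-epoch...
                stop_training = True
        if stop_training:              # ...but the loop only checks it here
            break
    return seen

# Wanting to stop at batch 10: with a long epoch, hundreds of extra
# batches still run; with a short epoch, only a few do.
print(batches_until_stop(10, steps_per_epoch=535))  # 535
print(batches_until_stop(10, steps_per_epoch=16))   # 16
```

This matches the output shown in the question, where "Stopping training!" is printed repeatedly while the 535-step epoch keeps running.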