python - 当验证损失满足某些标准时提前停止
问题描述
我正在 Keras 中训练一个神经网络模型。我想监控验证损失并在达到特定条件时停止训练。
我知道当给定轮EarlyStopping
数的训练没有改善时,我可以停止训练。patience
我想要一些不同的东西。我想在回合后val_loss
超过一个值时停止训练。x
n
为了清楚起见,让我们说x
in0.5
和n
is 50
。epoch
我只想在数字大于50
且val_loss
高于时才停止模型的训练0.5
。
我怎么能在 Keras 中做到这一点。?
解决方案
您可以通过继承 KerasEarlyStopping
回调并用您自己的逻辑覆盖它来定义自己的回调:
from keras.callbacks import EarlyStopping # use as base class
class MyCallBack(EarlyStopping):
def __init__(self, threshold, min_epochs, **kwargs):
super(MyCallBack, self).__init__(**kwargs)
self.threshold = threshold # threshold for validation loss
self.min_epochs = min_epochs # min number of epochs to run
def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
if current is None:
warnings.warn(
'Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
)
return
# implement your own logic here
if (epoch >= self.min_epochs) & (current >= self.threshold):
self.stopped_epoch = epoch
self.model.stop_training = True
小例子来说明它应该工作:
from keras.layers import Input, Dense
from keras.models import Model
import numpy as np
# Generate some random data
features = np.random.rand(100, 5)
labels = np.random.rand(100, 1)
validation_feat = np.random.rand(100, 5)
validation_labels = np.random.rand(100, 1)
# Define a simple model
input_layer = Input((5, ))
dense_layer = Dense(10)(input_layer)
output_layer = Dense(1)(dense_layer)
model = Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='mse', optimizer='sgd')
# Fit with custom callback
callbacks = [MyCallBack(threshold=0.001, min_epochs=10, verbose=1)]
model.fit(features, labels, validation_data=(validation_feat, validation_labels), callbacks=callbacks, epochs=100)
推荐阅读
- c# - C# System.Timers.Timer 类已用事件和定时器的一般注意事项
- c# - 在 swagger 中为 web api 中的特定版本添加 Bearer token 选项
- python - 如何在 python 中处理泡菜文件时找到准确性?
- javascript - 将 Runtime 更改为 Node.js 10 后云功能失败
- reactjs - 使用 refs 和 state onClick 定位 React 循环内的特定子项
- wix - 安装新版本时如何防止wix卸载旧版本?
- c++ - Visual Studio Code 的 C/C++ 扩展“变量“u8””不是类型名称”
- java - 损坏的 unicode 渲染,Java,仅限 Windows
- php - php中的sharepoint webhook集成不起作用
- docker - 使用 open shift dsl 从 open shift 中删除 docker 图像