python - 检查 val_acc 饱和度的回调
问题描述
通常,如果准确度达到一定水平,我们可以为模型定义一个回调以停止 epoch。
def LSTM_model(X_train, y_train, X_test, y_test, num_classes, batch_size=68, units=128, learning_rate=0.005, epochs=20,
dropout=0.2, recurrent_dropout=0.2):
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
if (logs.get('acc') > 0.90):
print("\nReached 90% accuracy so cancelling training!")
self.model.stop_training = True
callbacks = myCallback()
如图所示,val_acc(orange) 在一个范围内波动,并且不再真正上升。
一旦 val_acc 的总体趋势停止增加,有没有办法自动停止训练?
解决方案
您可以通过这样的方式实现此callback
目的
class terminate_on_plateau(keras.callbacks.Callback):
def __init__(self):
self.patience = 10
self.val_loss = deque([],self.patience)
self.std_threshold = 1e-2
def on_epoch_end(self,epoch,logs=None):
val_loss,val_mae = model.evaluate(x_val,y_val)
self.val_loss.append(val_loss)
if len(self.val_loss) >= self.patience:
std = np.std(self.val_loss)
if std < self.std_threshold:
print('\n\n EarlyStopping on std invoked! \n\n')
# clear the deque
self.val_loss = deque([],self.patience)
model.stop_training = True
如您所见, in terminate_on_plateau
, val_loss
of epoch 存储在 adeque
的 max lengthself.patience
中。deque
一旦达到的长度,将为每个新的 epoch计算的self.patience
标准差,如果计算的值小于阈值,则终止训练过程( 的也将被清除) 。val_loss
deque
val_loss
std
下面是一个简单的脚本,向您展示如何使用它
from collections import deque
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Dense
x = np.linspace(0,10,1000)
np.random.shuffle(x)
y = np.sin(x) + x
x_train,x_val,y_train,y_val = train_test_split(x,y,test_size=0.3)
input_x = Input(shape=(1,))
y = Dense(10,activation='relu')(input_x)
y = Dense(10,activation='relu')(y)
y = Dense(1,activation='relu')(y)
model = Model(inputs=input_x,outputs=y)
adamopt = tf.keras.optimizers.Adam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
class terminate_on_plateau(keras.callbacks.Callback):
def __init__(self):
self.patience = 10
self.val_loss = deque([],self.patience)
self.std_threshold = 1e-2
def on_epoch_end(self,epoch,logs=None):
val_loss,val_mae = model.evaluate(x_val,y_val)
self.val_loss.append(val_loss)
if len(self.val_loss) >= self.patience:
std = np.std(self.val_loss)
if std < self.std_threshold:
print('\n\n EarlyStopping on std invoked! \n\n')
# clear the deque
self.val_loss = deque([],self.patience)
model.stop_training = True
model.compile(loss='mse',optimizer=adamopt,metrics=['mae'])
history = model.fit(x_train,y_train,
batch_size=8,
epochs=100,
validation_data=(x_val, y_val),
verbose=1,
callbacks=[terminate_on_plateau()])
推荐阅读
- c# - BeginSendFile SocketException '参数不正确'
- azure-application-insights - Application Insights - 如何关联跨多个服务的操作?
- android - Android测试用例等待10分钟后才能在debug模式下调试
- python - scikit learn ExtraTreesClassifier 预测使用 Pandas DataFarme vs datatale Frame vs Numpy array 给出不同的执行时间
- vim - 尝试(但失败)让 cscope/ctags 在混合 C/C++ 项目中定位 C++ 函数
- html - Nikola:添加带有 id 的链接
- java - (GAE-Standard+Java11) 运行多个实例的会话
- python - 在 Pandas Dataframe 中计算时间间隔内的行数
- grails - Grails 视图为 DTO 对象列表呈现额外的逗号
- c# - 在 Azure 搜索中将模糊搜索与同义词扩展相结合