python - 当达到特定的损失和准确度值时,如何停止 tflearn 训练时期或迭代?
问题描述
我有一个模型可以使用 tflearn 库进行训练,我使用深度神经网络 (DNN) 来做到这一点。我们可以在这里看到更多(http://tflearn.org/models/dnn/)
下面是我的一段代码:
# Build neural network
net = tflearn.input_data(shape=[None, len(train_x[0])])
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, len(train_y[0]), activation='softmax')
net = tflearn.regression(net)
# Define model and setup tensorboard
model = tflearn.DNN(net, tensorboard_dir='tflearn_logs', best_val_accuracy=0.91)
# Start training (apply gradient descent algorithm)
model.fit(train_x, train_y, n_epoch=350, batch_size=8, show_metric=True)
model.save('model.tflearn')
当我运行该代码时,我会得到一些这样的值,直到 epoch 结束:
Training Step: 5083 | total loss: 0.31890 | time: 0.302s
| Adam | epoch: 085 | loss: 0.31890 - acc: 0.8948 -- iter: 344/474
Training Step: 20999 | total loss: 0.08880 | time: 0.366s
....
Training Step: 11279 | total loss: 0.10708 | time: 0.419s
| Adam | epoch: 188 | loss: 0.10708 - acc: 0.9556 -- iter: 472/474
Training Step: 11280 | total loss: 0.12302 | time: 0.425s
| Adam | epoch: 188 | loss: 0.12302 - acc: 0.9351 -- iter: 474/474
....
| Adam | epoch: 350 | loss: 0.08880 - acc: 0.9503 -- iter: 472/474
Training Step: 21000 | total loss: 0.08863 | time: 0.373s
| Adam | epoch: 350 | loss: 0.08863 - acc: 0.9553 -- iter: 474/474
任何人都知道每次损失和准确性达到特定值时如何停止训练?假设损失 0.05 和准确度 0.95。提前致谢
解决方案
Use Early Stopping through a callback instance given as argument to your fit method, like it's described here:
http://mckinziebrandon.me/TensorflowNotebooks/2016/11/20/early-stopping.html
Something like this should work for stopping training when accuracy reaches 0.95
class EarlyStoppingCallback(tflearn.callbacks.Callback):
def __init__(self, val_acc_thresh):
""" Note: We are free to define our init function however we please. """
self.val_acc_thresh = val_acc_thresh
def on_epoch_end(self, training_state):
""" """
# Apparently this can happen.
if training_state.val_acc is None: return
if training_state.val_acc > self.val_acc_thresh:
raise StopIteration
# Initializae our callback.
early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.95)
# Give it to our trainer and let it fit the data.
trainer.fit(feed_dicts={X: trainX, Y: trainY},
val_feed_dicts={X: testX, Y: testY},
n_epoch=2,
show_metric=True, # Calculate accuracy and display at every step.
snapshot_epoch=False,
callbacks=early_stopping_cb)
推荐阅读
- reactjs - 路由组件中的回调导致超过最大更新深度
- c# - WPF MVVM 父子关系
- python-3.x - Python PIL-NEAREST“对象没有属性”
- c# - 如何使用 C# 从 XML 中获取所有子元素名称及其父元素名称?
- c# - 找不到类型或命名空间名称“TextureImporter”(您是否缺少 using 指令或程序集引用?)
- c# - Microsoft.Office.Word.Interop 删除评论
- jquery - 使用 MutationObserver 比较新旧文本内容
- python - Tensorflow 服务错误“{“错误”:“格式错误的请求:POST /v1/models/cloths:predict
- sql - 寻找两个数据输出之间的差距
- react-native - 虚拟键盘将我的所有内容向上移动相同的距离,并且在我键入时某些文本字段不可见(Android)