首页 > 解决方案 > 如何达到满足提前停止条件的纪元数

问题描述

如果满足某些条件,我会使用回调来停止训练过程。我想知道如何访问由于回调而停止训练的纪元数。

import numpy as np
import random
import tensorflow as tf
from tensorflow import keras 


class stopAtLossValue(tf.keras.callbacks.Callback):
        def on_batch_end(self, batch, logs={}):
            eps = 0.01 
           
            if logs.get('loss') <= eps:
                 self.model.stop_training = True
                    
training_input=  np.random.random ([30,10])
training_output = np.random.random ([30,1])



model = tf.keras.Sequential([  
    tf.keras.layers.Flatten(input_shape=(10,)),
    tf.keras.layers.Dense(15,activation=tf.keras.activations.linear),
    tf.keras.layers.Dense(15, activation='relu'),
    tf.keras.layers.Dense(1) 
])   
                               
model.compile(loss="mse",optimizer = tf.keras.optimizers.Adam(learning_rate=0.01))

hist = model.fit(training_input, training_output, epochs=100, batch_size=100,  verbose=1, callbacks=[stopAtLossValue()])
    

对于这个例子,我的训练在第 66 个 epoch 完成,因为损失低于 0.01。

Epoch 66/100
1/1 [==============================] - 0s 5ms/step - loss: 0.0099
----------------------------------------------------------------- 

标签: pythontensorflowkerastensorflow2.0

解决方案


简单的方法是获取history.history对象的长度:

len(model.history.history['loss'])

更复杂的方法是从优化器获取迭代次数:

model.optimizer._iterations

推荐阅读