首页 > 解决方案 > 回调不在张量流中工作以停止训练

问题描述

我写了一个回调,当准确度达到 99% 时停止训练。但问题是我收到此错误。有时如果我解决此错误,即使 acuurqacy 变为 100%,也不会调用回调。

'NoneType' 和 'float' 的实例之间不支持 '>'

    class myCallback(tf.keras.callbacks.Callback):
        
        def on_epoch_end(self, epoch, logs={}):
            
            if(logs.get('accuracy') > 0.99):
                
                
               
               self.model.stop_training = True


def train_mnist():
    # Please write your code only where you are indicated.
    # please do not remove # model fitting inline comments.

    # YOUR CODE SHOULD START HERE

    # YOUR CODE SHOULD END HERE
    call = myCallback()
    mnist = tf.keras.datasets.mnist

    (x_train, y_train),(x_test, y_test) = mnist.load_data(path=path)
    # YOUR CODE SHOULD START
    x_train = x_train/255
    y_train = y_train/255
    # YOUR CODE SHOULD END HERE
    model = tf.keras.models.Sequential([
        # YOUR CODE SHOULD START HERE
          keras.layers.Flatten(input_shape=(28,28)),
          keras.layers.Dense(128,activation='relu'),
          keras.layers.Dense(10,activation='softmax')
        # YOUR CODE SHOULD END HERE
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    # model fitting
    history = model.fit(# YOUR CODE SHOULD START HERE
          x_train,y_train,epochs=9,callbacks=[call] )
    # model fitting
    return history.epoch, history.history['acc'][-1]

标签: pythontensorflowkeras

解决方案


上述代码的两个主要问题:

  • 在训练集上达到 100% 的准确率几乎总是意味着您的模型过度拟合。那很糟糕。您想要做的是validation_split=.2在方法中指定参数.fit,并在验证集上寻找高精度。
  • 您尝试在自定义回调中构建的内容已经完成keras.callbacks.EarlyStopping,它甚至可以选择在每个时期恢复到最佳整体模型。而且,默认情况下,如果您有验证拆分,它会寻找验证准确度,而不是训练准确度。

所以,这就是你应该做的:停止使用自定义回调,它们需要一些掌握才能开始工作。改用EarlyStoppingwith restore_best像这样 总是validation_split在验证集中使用并寻找高精度。就像在这个快速示例中一样


使用内置回调是否解决了您的问题?


推荐阅读