Only apply gradient descent if the TensorFlow model improves on training and validation data

Problem description

I want to customize the fit function of the model so that it applies gradient descent to the weights only if the model improved its predictions on the validation data. The reason for this is that I want to prevent overfitting.

According to this guide, it should be possible to customize the fit function of the model. However, the following code runs into an error:

class CustomModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        ### check and apply gradient
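        # X_val, Y_val, calculate_accuracy and last_acc_val are defined elsewhere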
        Y_pred_val = self.predict(X_val)                 # this does not work
        acc_val = calculate_accuracy(Y_val, Y_pred_val)

        if acc_val > last_acc_val:
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        ###

        self.compiled_metrics.update_state(y, y_pred)

        return_obj = {m.name: m.result() for m in self.metrics}
        return_obj["acc_val"] = acc_val
        return return_obj

How can the model be evaluated inside the fit function?

Tags: python, python-3.x, tensorflow, machine-learning, keras

Solution


You don't have to subclass fit() for this. You can make a custom training loop instead. Take a look at how I did it:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.keras import Model
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Concatenate
import tensorflow_datasets as tfds
from tensorflow.keras.regularizers import l1, l2, l1_l2
from collections import deque

dataset, info = tfds.load('mnist',
                          with_info=True,
                          split='train',
                          as_supervised=False)

TAKE = 1_000

data = dataset.map(lambda x: (tf.cast(x['image'], tf.float32), x['label']))
data = data.shuffle(TAKE).take(TAKE)

len_train = int(0.8 * TAKE)  # 80/20 train/test split

train = data.take(len_train).batch(8)
test = data.skip(len_train).batch(8)  # the remaining 20% of the TAKE examples


class CNN(Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = Dense(32, activation=tf.nn.relu,
                            kernel_regularizer=l1(1e-2),
                            input_shape=info.features['image'].shape)
        self.layer2 = Conv2D(filters=16,
                             kernel_size=(3, 3),
                             strides=(1, 1),
                             activation='relu',
                             input_shape=info.features['image'].shape)
        self.layer3 = MaxPooling2D(pool_size=(2, 2))
        self.layer4 = Conv2D(filters=32,
                             kernel_size=(3, 3),
                             strides=(1, 1),
                             activation=tf.nn.elu,
                             kernel_initializer=tf.keras.initializers.glorot_normal)
        self.layer5 = MaxPooling2D(pool_size=(2, 2))
        self.layer6 = Flatten()
        self.layer7 = Dense(units=64,
                            activation=tf.nn.relu,
                            kernel_regularizer=l2(1e-2))
        self.layer8 = Dense(units=64,
                            activation=tf.nn.relu,
                            kernel_regularizer=l1_l2(l1=1e-2, l2=1e-2))
        self.layer9 = Concatenate()
        self.layer10 = Dense(units=info.features['label'].num_classes)

    def call(self, inputs, training=None, **kwargs):
        b = self.layer1(inputs)
        a = self.layer2(inputs)
        a = self.layer3(a)
        a = self.layer4(a)
        a = self.layer5(a)
        a = self.layer6(a)
        a = self.layer8(a)
        b = self.layer7(b)
        b = self.layer6(b)
        x = self.layer9([a, b])
        x = self.layer10(x)
        return x


cnn = CNN()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

optimizer = tf.keras.optimizers.Nadam()

template = 'Epoch {:3} Train Loss {:7.4f} Test Loss {:7.4f} ' \
           'Train Acc {:6.2%} Test Acc {:6.2%} '

epochs = 5
early_stop = max(1, epochs // 50)  # early-stopping patience, at least one epoch

loss_hist = deque()         # per-epoch test losses for the early-stopping check
acc_hist = deque(maxlen=1)  # last batch accuracy that triggered a weight update
acc_hist.append(0)

for epoch in range(1, epochs + 1):
    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for images, labels in train:
        with tf.GradientTape() as tape:
            logits = cnn(images, training=True)
            loss = loss_object(labels, logits)
            train_loss(loss)
            train_acc(labels, logits)

            # score this batch alone with a fresh metric (not the epoch average)
            current_acc = tf.metrics.SparseCategoricalAccuracy()(labels, logits)

            if tf.greater(current_acc, acc_hist[-1]):
                print('IMPROVEMENT.')
                gradients = tape.gradient(loss, cnn.trainable_variables)
                optimizer.apply_gradients(zip(gradients, cnn.trainable_variables))
                acc_hist.append(current_acc)

    for images, labels in test:
        logits = cnn(images, training=False)
        loss = loss_object(labels, logits)
        test_loss(loss)
        test_acc(labels, logits)

    print(template.format(epoch,
                          train_loss.result(),
                          test_loss.result(),
                          train_acc.result(),
                          test_acc.result()))

    # Record this epoch's test loss; without this append, loss_hist stays empty
    # and the early-stopping check below can never fire.
    loss_hist.append(test_loss.result())

    if len(loss_hist) > early_stop and loss_hist.popleft() < min(loss_hist):
        print('Early stopping. No test loss decrease in %i epochs.' % early_stop)
        break

Output:

IMPROVEMENT.
IMPROVEMENT.
IMPROVEMENT.
IMPROVEMENT.
Epoch   1 Train Loss 21.1698 Test Loss 21.3391 Train Acc 37.13% Test Acc 38.50% 
IMPROVEMENT.
IMPROVEMENT.
IMPROVEMENT.
Epoch   2 Train Loss 13.8314 Test Loss 12.2496 Train Acc 50.88% Test Acc 52.50% 
Epoch   3 Train Loss 13.7594 Test Loss 12.5884 Train Acc 51.75% Test Acc 53.00% 
Epoch   4 Train Loss 13.1418 Test Loss 13.2374 Train Acc 52.75% Test Acc 51.50% 
Epoch   5 Train Loss 13.6471 Test Loss 13.3157 Train Acc 49.63% Test Acc 51.50% 

Here is the part that does the work. It uses a deque: a fresh accuracy metric scores the current batch, and if that score does not beat the last element of the deque, the application of the gradients is skipped.

    for images, labels in train:
        with tf.GradientTape() as tape:
            logits = cnn(images, training=True)
            loss = loss_object(labels, logits)
            train_loss(loss)
            train_acc(labels, logits)

            # score this batch alone with a fresh metric (not the epoch average)
            current_acc = tf.metrics.SparseCategoricalAccuracy()(labels, logits)

            if tf.greater(current_acc, acc_hist[-1]):
                print('IMPROVEMENT.')
                gradients = tape.gradient(loss, cnn.trainable_variables)
                optimizer.apply_gradients(zip(gradients, cnn.trainable_variables))
                acc_hist.append(current_acc)
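
Note that this loop gates each update on the accuracy of the current training batch, not on validation data as the question asked. If you would rather keep fit(), the reason the original attempt fails is that self.predict() starts its own prediction loop and cannot be called from inside train_step, which Keras traces into a graph; a plain forward pass with self(x_val, training=False) works instead. Below is a minimal sketch of that variant, not code from the original post: it assumes the model is compiled with run_eagerly=True (so the Python if on a tensor is legal), x_val/y_val are a held-out validation batch you supply yourself, and the toy two-layer architecture is only there to keep the sketch self-contained.

import tensorflow as tf


class ValidationGatedModel(tf.keras.Model):
    def __init__(self, x_val, y_val, **kwargs):
        super().__init__(**kwargs)
        # Hypothetical held-out validation batch, supplied by the caller.
        self.x_val = tf.convert_to_tensor(x_val)
        self.y_val = tf.convert_to_tensor(y_val)
        self.best_val_acc = tf.Variable(0.0, trainable=False)
        self.val_metric = tf.keras.metrics.SparseCategoricalAccuracy()
        # Toy architecture, just to make the sketch runnable on MNIST-like data.
        self.flatten = tf.keras.layers.Flatten()
        self.out = tf.keras.layers.Dense(10)

    def call(self, inputs, training=None):
        return self.out(self.flatten(inputs))

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)

        # Plain forward pass instead of self.predict(), which cannot run here.
        val_logits = self(self.x_val, training=False)
        self.val_metric.reset_states()
        self.val_metric.update_state(self.y_val, val_logits)
        val_acc = self.val_metric.result()

        # Apply the update only when the validation accuracy improved.
        # (The Python if needs run_eagerly=True; in graph mode use tf.cond.)
        if val_acc > self.best_val_acc:
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.best_val_acc.assign(val_acc)

        self.compiled_metrics.update_state(y, y_pred)
        results = {m.name: m.result() for m in self.metrics}
        results['val_acc'] = val_acc
        return results

Usage would then look like a normal fit() call, with x_val, y_val and train_ds standing in for your own data:

model = ValidationGatedModel(x_val, y_val)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              run_eagerly=True)
model.fit(train_ds, epochs=5)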
