Val loss behaves strangely when using a custom training loop in TensorFlow 2.0

Problem description

I am training a VGG16 model written in tf2.0 on my own dataset. The model contains several BatchNormalization layers, and the training argument is set to True at training time and to False at validation time, as described in many tutorials. The train_loss drops to a certain level during training, as expected.
However, the val_loss behaves very strangely. I checked the model's output after training and found that if I set the training argument to True the output is quite correct, but if I set it to False the result is completely wrong. According to the tutorials on the TensorFlow website, when training is set to False the model should normalize its input using the mean and variance of the moving statistics learned during training, but that does not seem to be happening here. Am I missing something?
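
For reference, here is a minimal standalone sketch (not from the question) of the documented semantics of the training flag on a BatchNormalization layer: with training=True it normalizes with the statistics of the current batch and updates its moving averages, and with training=False it normalizes with the accumulated moving mean and variance.

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal([32, 4], mean=5.0, stddev=2.0)

# training=True: normalize with the batch statistics and update the
# layer's moving_mean / moving_variance.
for _ in range(100):
    _ = bn(x, training=True)
print(bn.moving_mean.numpy())  # drifts toward the data mean (~5.0)

# training=False: normalize with the accumulated moving statistics;
# moving_mean / moving_variance are left untouched.
y = bn(x, training=False)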

The training and validation code is provided below.

import os
import logging
import tensorflow as tf

# Globals such as train_img_list, train_label_list, val_img_list, val_label_list,
# parse_function, batch_size, repeat_times, num_label, ckpt_path, initial_lr,
# log_path, case_num and total_epoch are defined elsewhere in the original script.
def train():
    logging.basicConfig(level=logging.INFO)
    tdataset = tf.data.Dataset.from_tensor_slices((train_img_list[:200], train_label_list[:200]))
    tdataset = tdataset.map(parse_function, 3).shuffle(buffer_size=200).batch(batch_size).repeat(repeat_times)
    vdataset = tf.data.Dataset.from_tensor_slices((val_img_list[:100], val_label_list[:100]))
    vdataset = vdataset.map(parse_function, 3).batch(batch_size)

    ### Vgg model
    model = VGG_PR(num_classes=num_label)

    logging.info('Model loaded')

    start_epoch = 0
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if latest_ckpt:
        start_epoch = int(latest_ckpt.split('-')[1].split('.')[0])
        model.load_weights(latest_ckpt)
        logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch))
    else:
        logging.info('training from scratch since there are no saved weights')

    ######## training loop ########
    loss_object = tf.keras.losses.MeanSquaredError()
    val_loss_object = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr)
    train_loss = tf.metrics.Mean(name='train_loss') 
    val_loss = tf.metrics.Mean(name='val_loss')
    writer = tf.summary.create_file_writer(log_path.format(case_num))

    with writer.as_default():
        for epoch in range(start_epoch, total_epoch):
            print('start training')
            try:
                for batch, data in enumerate(tdataset):
                    images, labels = data
                    with tf.GradientTape() as tape:
                        pred = model(images, training=True)  # BatchNormalization runs on batch statistics
                        if len(pred.shape) == 2:
                            pred = tf.reshape(pred, [-1, 1, 1, num_label])
                        loss = loss_object(labels, pred)  # Keras losses expect (y_true, y_pred)
                    gradients = tape.gradient(loss, model.trainable_variables)
                    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                    if batch % 20 == 0:
                        logging.info('Epoch: {}, iter: {}, loss:{}'.format(epoch, batch, loss.numpy()))
                    tf.summary.scalar('train_loss', loss.numpy(), step=epoch*1250*repeat_times+batch)      # the tdataset has been repeated 5 times..
                    tf.summary.text('Zernike_coe_pred', tf.as_string(tf.squeeze(pred)), step=epoch*1250*repeat_times+batch)
                    tf.summary.text('Zernike_coe_gt', tf.as_string(tf.squeeze(labels)), step=epoch*1250*repeat_times+batch)

                    writer.flush()
                    train_loss(loss)
                model.save_weights(ckpt_path.format(epoch=epoch))
            except KeyboardInterrupt:
                logging.info('interrupted.')
                model.save_weights(ckpt_path.format(epoch=epoch))
                logging.info('model saved into {}'.format(ckpt_path.format(epoch=epoch)))
                exit(0)
            # validation step
            for batch, data in enumerate(vdataset):
                images, labels = data
                val_pred = model(images, training=False)  # BatchNormalization should use its moving statistics here
                if len(val_pred.shape) == 2:
                    val_pred = tf.reshape(val_pred, [-1, 1, 1, num_label])
                v_loss = val_loss_object(labels, val_pred)  # (y_true, y_pred) order
                val_loss(v_loss)
            logging.info('Epoch: {}, average train_loss:{}, val_loss: {}'.format(epoch, train_loss.result(), val_loss.result()))
            tf.summary.scalar('val_loss', val_loss.result(), step=epoch)
            writer.flush()
            train_loss.reset_states()
            val_loss.reset_states()
        model.save_weights(ckpt_path.format(epoch=epoch))

The training loss decreases to a very small value: with the ground-truth labels in the [0, 1] range, the average train loss can reach 0.007, but the val loss stays far above that. If I set training to False, the model's outputs tend toward values close to 0.

Update, Nov 6: I found an interesting thing: if I decorate the call method of my model with tf.function, the val loss becomes correct, but I am not sure what is going on here.

Tags: python-3.x, tensorflow, deep-learning

Solution


Mentioning the answer here for the benefit of the community.

The issue is resolved: the val loss becomes correct once tf.function is used to decorate the call method of the model.
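
A minimal sketch of that fix, assuming VGG_PR is a subclassed tf.keras.Model; the layer stack below is illustrative, since the real model code is not shown in the question:

import tensorflow as tf

class VGG_PR(tf.keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        # Illustrative backbone; the question's actual VGG16 layers are not shown.
        self.backbone = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(num_classes),
        ])

    @tf.function  # decorating call is what made the val loss behave correctly
    def call(self, inputs, training=False):
        # Forward the training flag so BatchNormalization switches between
        # batch statistics (training=True) and moving statistics (training=False).
        return self.backbone(inputs, training=training)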

