What is the proper way to evaluate (e.g. test, validate) a deep learning model in PyTorch so that checkpoints are saved correctly?

Problem description

I usually alternate between training and testing in my code. This results in me calling mdl.train() when going back to the training loop and mdl.eval() when entering the test loop. However, I've noticed this can cause problems with how parameters that depend on the eval/train state are saved (e.g. batch norm, perhaps others). So it begs the question: how, if at all, should these flags be set? If I call eval before checkpointing, I believe it will wipe out my running averages, as outlined here: "How does one use the mean and std from training in Batch Norm?".
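
As a sanity check, here is a minimal sketch (plain torch only; the toy model is mine, not from the original post) suggesting that calling .eval() by itself does not erase the BatchNorm running statistics: they are stored as buffers, so they survive mode switches and are included in state_dict() in either mode.

import torch
import torch.nn as nn

mdl = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))

# a few forward passes in train mode update the BN running statistics
mdl.train()
for _ in range(3):
    mdl(torch.randn(16, 8))
running_mean_before = mdl[1].running_mean.clone()

# eval mode only changes which statistics the forward pass *uses*;
# the stored running statistics themselves are untouched
mdl.eval()
mdl(torch.randn(16, 8))
assert torch.equal(mdl[1].running_mean, running_mean_before)

# running stats are buffers, so they land in the checkpoint regardless of mode
assert '1.running_mean' in mdl.state_dict()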

I thought that making a deep copy of the model for evaluation would solve the problem, but that seems to kill my GPU memory (and probably my regular RAM too).
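
One workaround in that direction, sketched here under the assumption that evaluation can tolerate the copy overhead (evaluate_on_copy and val_fn are hypothetical names), is to deep-copy the model to the CPU for evaluation, so the GPU copy and its running statistics are never touched and GPU memory is not duplicated:

import copy

import torch

def evaluate_on_copy(model, val_fn):
    # hypothetical helper: evaluate a CPU deep copy so the original GPU model
    # (including its BN running statistics) stays untouched
    eval_model = copy.deepcopy(model).cpu()
    eval_model.eval()
    with torch.no_grad():  # fine for plain supervised eval; MAML-style eval needs grads
        return val_fn(eval_model)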

So, what is the proper way to alternate between evaluation and training in PyTorch so that checkpoints are saved correctly (e.g. without wiping out the running statistics from training)?

This is tricky because I usually always run an evaluation to see if the current model is better on validation than the previous one, and only then decide whether to save it, which requires me to evaluate before checkpointing, always.
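
For ordinary supervised learning the common pattern is simply: switch to eval mode, run validation, checkpoint state_dict() if validation improved, then switch back to train mode. A minimal sketch (model, val_loader, criterion, best_val_loss are placeholders, not names from the original code):

import torch

def validate_and_maybe_checkpoint(model, val_loader, criterion, best_val_loss,
                                  path='ckpt_best_val.pt'):
    model.eval()
    total_loss, n = 0.0, 0
    with torch.no_grad():  # plain supervised eval needs no gradients
        for x, y in val_loader:
            total_loss += criterion(model(x), y).item() * len(x)
            n += len(x)
    val_loss = total_loss / n
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), path)  # BN buffers are saved too
    model.train()  # restore train mode before resuming the training loop
    return best_val_loss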

Perhaps I could always run it in train mode so that the running averages are saved correctly, and since the actual validation values don't matter that much, it would be fine if the train and val statistics leak into the validation estimate.

I suppose always running everything with batch statistics is another option...
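
If one really wants batch statistics at eval time, one way (a sketch, not the only one; force_batch_stats is my name) is to flip only the normalization layers back to train mode after model.eval(). Note this keeps updating the running statistics with validation batches, i.e. exactly the train/val leakage described above; constructing the BN layers with track_running_stats=False instead avoids keeping running statistics altogether.

import torch.nn as nn

def force_batch_stats(model: nn.Module) -> nn.Module:
    # hypothetical helper: eval mode everywhere, except BN layers stay in
    # train mode so the forward pass normalizes with the current batch
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.train()  # caution: this also updates running stats with eval data
    return model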

Some code snippets I usually use for MAML (meta-learning, but they should be easy to adapt to normal supervised learning):

def meta_train_fixed_iterations_full_epoch_possible(args):
    """
    Train using the meta-training (e.g. episodic training) over batches of tasks using a fixed number of iterations
    assuming the number of tasks is small i.e. one epoch is doable and not infinite/super exponential
    (e.g. in regression when a task can be considered as a function).

    Note: if the number of tasks is small then we have two loops: one while we have not finished all the fixed
    iterations, and the other over the dataloader for the tasks.
    """
    # warnings.simplefilter("ignore")
    # uutils.torch_uu.distributed.dist_log('Starting training...')
    print('Starting training!')

    # bar_it = uutils.get_good_progressbar(max_value=progressbar.UnknownLength)
    bar_it = uutils.get_good_progressbar(max_value=args.num_its)
    args.it = 0
    while True:
        for batch_idx, batch in enumerate(args.dataloaders['train']):
            args.batch_idx = batch_idx
            spt_x, spt_y, qry_x, qry_y = process_meta_batch(args, batch)

            # - clean gradients, especially before the meta-learner is run since it uses gradients
            args.outer_opt.zero_grad()

            # - forward pass A(f)(x)
            train_loss, train_acc = args.meta_learner(spt_x, spt_y, qry_x, qry_y)
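            # (no explicit train_loss.backward() call in this snippet; it assumes
            # the meta-learner's forward pass populates the gradients itself)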

            # - outer_opt step
            gradient_clip(args, args.outer_opt)  # do gradient clipping: if ‖g‖ ≥ c then g := c * g/‖g‖
            args.outer_opt.step()

            # - scheduler
            if (args.it % 500 == 0 and args.it != 0 and args.scheduler is not None) or args.debug:  # step the scheduler every 500 its
                args.scheduler.step()

            # -- log it stats
            log_train_val_stats(args, args.it, train_loss, train_acc, valid=meta_eval, bar=bar_it,
                                log_freq=100, ckpt_freq=500,
                                save_val_ckpt=True, log_to_wandb=args.log_to_wandb)
            log_sim_to_check_presence_of_feature_reuse(args, args.it,
                                                       spt_x, spt_y, qry_x, qry_y,
                                                       log_freq_for_detection_of_feature_reuse=int(args.num_its//3)
                                                       , parallel=False)

            # - break
            halt: bool = args.it >= args.num_its - 1
            if halt:
                return train_loss, train_acc

            args.it += 1

# - evaluation code

def meta_eval(args: Namespace, val_iterations: int = 0, save_val_ckpt: bool = True, split: str = 'val') -> tuple:
    """
    Evaluates the meta-learner on the given meta-set.

    ref for BN/eval:
        - https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du
        - https://github.com/tristandeleu/pytorch-maml/issues/19
    """
    # - need to re-implement if you want to go through the entire data-set to compute an epoch (no more is ever needed)
    assert val_iterations == 0, f'Val iterations has to be zero but got {val_iterations}, if you want more precision increase (meta) batch size.'
    args.meta_learner.eval()
    for batch_idx, batch in enumerate(args.dataloaders[split]):
        spt_x, spt_y, qry_x, qry_y = process_meta_batch(args, batch)

        # Forward pass
        eval_loss, eval_acc = args.meta_learner(spt_x, spt_y, qry_x, qry_y)
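        # (note: no torch.no_grad() here, presumably because MAML-style evaluation
        # still adapts in the inner loop, which needs gradients)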

        # - stop after the requested number of eval batches
        if batch_idx >= val_iterations:
            break

    save_val_ckpt = False if split == 'test' else save_val_ckpt  # don't save models based on test set
    if float(eval_loss) < float(args.best_val_loss) and save_val_ckpt:
        args.best_val_loss = float(eval_loss)
        save_for_meta_learning(args, ckpt_filename='ckpt_best_val.pt')
    return eval_loss, eval_acc
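
To make the mode switching itself harder to get wrong, a generic helper along these lines can be used (a sketch; the name evaluating is mine): a context manager that puts the model in eval mode and restores whatever mode it was in before, even if evaluation raises.

from contextlib import contextmanager

import torch.nn as nn

@contextmanager
def evaluating(model: nn.Module):
    """Temporarily switch to eval mode; restore the previous mode on exit."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)

With this, the validation pass becomes `with evaluating(args.meta_learner): ...`, and the training loop never resumes with a stale eval flag.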


Tags: machine-learning, deep-learning, pytorch, conv-neural-network

Solution

