首页 > 解决方案 > tqdm 笔记本 - 2 个内条

问题描述

我正在尝试使用 tqdm 打印训练进度条。
我想跟踪时代的进度,对于每个时代,我都有 2 个进度条:train_loader minibatches 和 validation_loader minibatches。
代码是这样的:

for epoch in tqdm(range(epochs_num)):
    for inputs, labels in tqdm(train_loader, "Train progress", leave=False):
        # train...
    with torch.no_grad():
        for inputs, labels in tqdm(validation_loader, "Validation progress", leave=False):
            # calc validation loss

使用该leave参数,进度条删除了每个时期,但我想在验证过程结束后将它们一起删除。
有什么办法吗?

谢谢

标签: pythonpytorchtraining-datatqdm

解决方案


您可以重复使用进度条并手动进行更新,如下所示:

epochs = tqdm(range(epochs_num), desc="Epochs")
training_progress = tqdm(total=training_batch_size, desc="Training progress")
validation_progress = tqdm(total=validation_batch_size, desc="Validation progress")

for epoch in epochs:
    training_progress.reset()
    validation_progress.reset()
    
    for inputs, labels in train_loader:
        # train...
        training_progress.update()
        
    with torch.no_grad():
        for inputs, labels in validation_loader:
            # calc validation loss
            validation_progress.update()

如果批量大小并不总是相同,您可以即时计算它们并调用reset(total=new_size).


推荐阅读