Iteration count and dataset length are wrong, and data_time_interval is too large. How can these problems be fixed?

Problem description

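I have 4 GPUs and want to train with PyTorch DDP. My test dataset has 80 training samples and 20 validation samples, and the batch size is set to 16. During training I want to print the iteration index i together with the length of the training loader, which I obtain as args.step_per_epoch = train_data.__len__(). The train and val code is below.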

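import math
import time

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# VimeoDataset, args, device, get_learning_rate, epochlist and psnrlist are not shown here.

def train(model, local_rank):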
    model_path = '/output/saved_model'
    step = 0
    nr_eval = 0
    dataset = VimeoDataset(mode = 'train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = train_data.__len__()
    dataset_val = VimeoDataset(mode = 'val')
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=4)
    evaluate(model, val_data, nr_eval, local_rank)
    epochlist.append(0)
    model.save_model(model_path, local_rank)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu, flow_gt = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            flow_gt = flow_gt.to(device, non_blocking=True)
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            mul = np.cos(step / (args.epoch * args.step_per_epoch) * math.pi) * 0.5 + 0.5
            learning_rate = get_learning_rate(step)
            pred, merged_img, flow, loss_LPIPS, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, learning_rate, mul, True, flow_gt)
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_LPIPS:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss_LPIPS))
            step += 1
        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank)
            epochlist.append(nr_eval)
        model.save_model(model_path, local_rank)    
        dist.barrier()

def evaluate(model, val_data, nr_eval, local_rank):
    psnr_list = []
    time_stamp = time.time()
    for i, data in enumerate(val_data):
        data_gpu, flow_gt = data
        data_gpu = data_gpu.to(device, non_blocking=True) / 255.
        flow_gt = flow_gt.to(device, non_blocking=True)
        imgs = data_gpu[:, :6]
        gt = data_gpu[:, 6:9]
        with torch.no_grad():
            pred, merged_img, flow, loss_LPIPS, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, training=False)
        for j in range(gt.shape[0]):
            psnr = -10 * math.log10(torch.mean((gt[j] - pred[j]) * (gt[j] - pred[j])).cpu().data)
            psnr_list.append(psnr)
    
    eval_time_interval = time.time() - time_stamp
    if local_rank == 0:
        print('eval time: {}'.format(eval_time_interval)) 
        print('mean psnr: {}'.format(np.mean(psnr_list)))
        psnrlist.append(np.mean(psnr_list))

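Below is part of the training log. The values of i and args.step_per_epoch look wrong: they are always 0/1. What causes this, and how should I change it? The log also shows that data_time_interval is actually larger than train_time_interval. How can I reduce data_time_interval to make training more efficient?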

eval time: 8.030341386795044
mean psnr: 24.4023611466035
training...
epoch:0 0/1 time:0.30+2.96 loss_LPIPS:3.4511e-01
epoch:1 0/1 time:0.93+0.41 loss_LPIPS:3.3067e-01
epoch:2 0/1 time:1.29+0.36 loss_LPIPS:3.3386e-01
epoch:3 0/1 time:4.61+0.37 loss_LPIPS:3.2475e-01
epoch:4 0/1 time:5.26+0.36 loss_LPIPS:3.1935e-01
eval time: 0.9092228412628174
mean psnr: 24.403575688421018
epoch:5 0/1 time:3.94+0.36 loss_LPIPS:3.5425e-01
epoch:6 0/1 time:4.75+0.36 loss_LPIPS:3.5130e-01
epoch:7 0/1 time:3.60+0.36 loss_LPIPS:3.2492e-01
epoch:8 0/1 time:1.40+0.37 loss_LPIPS:3.4967e-01
epoch:9 0/1 time:1.12+0.37 loss_LPIPS:3.4065e-01
...
...
epoch:90 0/1 time:7.11+0.38 loss_LPIPS:3.1828e-01
epoch:91 0/1 time:2.66+0.37 loss_LPIPS:2.7712e-01
epoch:92 0/1 time:6.57+0.36 loss_LPIPS:3.0946e-01
epoch:93 0/1 time:5.59+0.36 loss_LPIPS:2.5663e-01
epoch:94 0/1 time:0.76+0.36 loss_LPIPS:2.8386e-01
eval time: 0.8664124011993408
mean psnr: 24.744649141015156
epoch:95 0/1 time:1.89+0.35 loss_LPIPS:2.8509e-01
epoch:96 0/1 time:2.24+0.36 loss_LPIPS:3.0353e-01
epoch:97 0/1 time:1.49+0.37 loss_LPIPS:3.0354e-01
epoch:98 0/1 time:1.34+0.36 loss_LPIPS:2.9313e-01
epoch:99 0/1 time:1.27+0.36 loss_LPIPS:2.9234e-01
eval time: 0.8701093196868896
mean psnr: 24.784283493153136
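
The 0/1 in the log above appears to be what the loader settings imply rather than a counting bug. With DistributedSampler each process sees only its own share of the dataset, and batch_size is per process, so 80 samples over 4 GPUs leaves 20 per process, and with drop_last=True only one full batch of 16 fits. A minimal sketch of that arithmetic follows; batches_per_rank is a hypothetical helper written for illustration, not a PyTorch function, and it assumes the sampler's default padding behaviour (each rank gets ceil(N / world_size) indices).

```python
import math


def batches_per_rank(dataset_len, world_size, batch_size, drop_last=True):
    """Estimate len(DataLoader) on one DDP process (illustrative helper)."""
    # By default DistributedSampler pads the index list so that every
    # rank receives ceil(dataset_len / world_size) indices.
    samples_per_rank = math.ceil(dataset_len / world_size)
    if drop_last:
        # DataLoader(drop_last=True) discards the final incomplete batch.
        return samples_per_rank // batch_size
    return math.ceil(samples_per_rank / batch_size)


# Settings from the question: 80 samples, 4 GPUs, batch size 16.
print(batches_per_rank(80, 4, 16))                   # 1
print(batches_per_rank(80, 4, 16, drop_last=False))  # 2
print(batches_per_rank(80, 4, 4))                    # 5
```

Under these assumptions, seeing more than one iteration per epoch would need a larger dataset, a smaller per-process batch size, or drop_last=False. The effective global batch here is 16 x 4 = 64 samples per step.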

Tags: python, deep-learning, pytorch, dataset

Solution
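
On data_time_interval: in the posted loop, time_stamp is last set right after model.update, and the log shows a single iteration per epoch. The next data_time_interval therefore covers everything up to the arrival of the next batch, which includes model.save_model, the periodic evaluate, dist.barrier, and building a fresh DataLoader iterator (which starts the worker processes again). A minimal sketch of that effect, using only the standard library, is below; measure_data_intervals and fake_epoch_end are hypothetical stand-ins for illustration, not part of the original code.

```python
import time


def measure_data_intervals(num_epochs, steps_per_epoch, epoch_end_work,
                           reset_each_epoch):
    """Return the data_time_interval recorded at every step."""
    intervals = []
    time_stamp = time.perf_counter()
    for epoch in range(num_epochs):
        if reset_each_epoch:
            # Restart the clock so epoch-end work is not billed to data loading.
            time_stamp = time.perf_counter()
        for step in range(steps_per_epoch):
            intervals.append(time.perf_counter() - time_stamp)
            time_stamp = time.perf_counter()
            # ... the training step would run here ...
            time_stamp = time.perf_counter()
        # Stand-in for save_model / evaluate / dist.barrier.
        epoch_end_work()
    return intervals


def fake_epoch_end():
    time.sleep(0.05)  # pretend checkpointing takes about 50 ms


without_reset = measure_data_intervals(3, 1, fake_epoch_end, reset_each_epoch=False)
with_reset = measure_data_intervals(3, 1, fake_epoch_end, reset_each_epoch=True)
print(["{:.3f}".format(t) for t in without_reset])
print(["{:.3f}".format(t) for t in with_reset])
```

Resetting the timestamp at the start of each epoch only makes the measurement more honest; it does not remove the worker start-up cost. For that, passing persistent_workers=True to DataLoader (available since PyTorch 1.7, requires num_workers > 0) keeps the workers alive between epochs and may be worth trying, as may saving checkpoints less often than once per epoch.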

