python - 迭代和数据集长度出错。并且 data_time_interval 太大。如何解决这些问题?
问题描述
我有 4 个 GPU,想使用 PyTorch DDP 进行训练。我测试用的训练集(train)有 80 个样本,验证集(val)有 20 个样本,batch size 设置为 16。我想在训练时打印迭代次数 i 以及训练集的长度 args.step_per_epoch = train_data.__len__()。下面是 train 和 val 的代码。
def train(model, local_rank):
    """Run the DDP training loop for ``args.epoch`` epochs in this process.

    Evaluates and checkpoints once before the first epoch, then again after
    every 5th epoch, with a ``dist.barrier()`` at the end of each epoch.

    Why the progress line reads ``0/1``: ``DistributedSampler`` gives each
    rank only ``len(sampler)`` samples (roughly ``len(dataset) / world_size``),
    and ``args.batch_size`` is the *per-process* batch size. With 80 training
    samples on 4 GPUs each rank sees 20 samples. With ``batch_size=16`` and
    ``drop_last=True`` that is exactly one batch per epoch, and ``i`` is
    0-based. To get more steps per epoch, lower the per-process batch size
    (e.g. total batch size // world size) or use more data.

    Timing columns in the log: ``data_time`` is the wait for the next batch
    since the previous step finished; ``train_time`` is the (synchronised)
    time of ``model.update``.

    Args:
        model: object providing ``update(...)`` and ``save_model(path, rank)``
            as called below.
        local_rank: rank of this process; only rank 0 prints per-step progress.

    Raises:
        ValueError: if a rank would receive zero training batches per epoch.
    """
    model_path = '/output/saved_model'
    step = 0
    nr_eval = 0
    dataset = VimeoDataset(mode='train')
    sampler = DistributedSampler(dataset)
    # persistent_workers keeps the worker processes alive across epochs.
    # Without it the DataLoader shuts them down at the end of every epoch and
    # re-spawns them on the next `enumerate`. With one batch per epoch, that
    # start-up cost was being paid (and measured as data time) on every step.
    train_data = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
        sampler=sampler,
        persistent_workers=True,
    )
    args.step_per_epoch = len(train_data)
    if args.step_per_epoch == 0:
        # A rank with fewer than batch_size samples gets no batch at all under
        # drop_last=True; the loop below would then silently never train.
        raise ValueError(
            'each rank has only {} training samples but batch_size={}; with '
            'drop_last=True no batch is produced. Lower args.batch_size or '
            'add training data.'.format(len(sampler), args.batch_size))
    dataset_val = VimeoDataset(mode='val')
    # Same reason as above: evaluate() iterates this loader repeatedly.
    val_data = DataLoader(
        dataset_val,
        batch_size=16,
        pin_memory=True,
        num_workers=4,
        persistent_workers=True,
    )
    evaluate(model, val_data, nr_eval, local_rank)
    epochlist.append(0)
    model.save_model(model_path, local_rank)
    print('training...')
    time_stamp = time.perf_counter()
    for epoch in range(args.epoch):
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            # Time spent waiting for this batch since the previous step ended.
            data_time_interval = time.perf_counter() - time_stamp
            time_stamp = time.perf_counter()
            data_gpu, flow_gt = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            flow_gt = flow_gt.to(device, non_blocking=True)
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            # Cosine factor going from 1 down to 0 over all training steps.
            mul = np.cos(step / (args.epoch * args.step_per_epoch) * math.pi) * 0.5 + 0.5
            learning_rate = get_learning_rate(step)
            pred, merged_img, flow, loss_LPIPS, loss_flow, loss_cons, loss_ter, flow_mask = model.update(
                imgs, gt, learning_rate, mul, True, flow_gt)
            # CUDA kernels run asynchronously. Without a sync, unfinished GPU
            # work would be charged to the next step's data time instead.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            train_time_interval = time.perf_counter() - time_stamp
            time_stamp = time.perf_counter()
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_LPIPS:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss_LPIPS))
            step += 1
        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank)
            epochlist.append(nr_eval)
            model.save_model(model_path, local_rank)
        dist.barrier()
        # Restart the clock so the next batch's data time does not include
        # evaluation, checkpointing, or the barrier wait.
        time_stamp = time.perf_counter()
def evaluate(model, val_data, nr_eval, local_rank):
    """Compute the mean PSNR of ``model`` over ``val_data`` and record it.

    Every rank iterates the whole validation loader; only rank 0 prints the
    elapsed time and the mean PSNR. The mean is appended to the module-level
    ``psnrlist`` on every rank.

    Args:
        model: object whose ``update(imgs, gt, training=False)`` returns the
            prediction as its first value.
        val_data: iterable of ``(frames, flow_gt)`` batches. The code scales
            ``frames`` by ``1/255`` and uses channels 0-5 as input and 6-8 as
            the target (assumes 0-255 pixel values — confirm against dataset).
        nr_eval: not used by this function; kept for call compatibility.
        local_rank: rank of this process; only rank 0 prints.
    """
    psnr_list = []
    time_stamp = time.perf_counter()
    with torch.no_grad():
        for data in val_data:
            # flow_gt is not needed for evaluation, so it is not moved to the device.
            data_gpu, _ = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            pred, *_ = model.update(imgs, gt, training=False)
            # Per-sample MSE in a single reduction: one device->host copy per
            # batch instead of one blocking .cpu() per sample.
            mse_per_sample = ((gt - pred) * (gt - pred)).flatten(1).mean(dim=1)
            # Clamp so a perfect reconstruction (MSE == 0) yields a finite
            # 100 dB instead of math.log10 raising "math domain error".
            psnr_list.extend(
                -10 * math.log10(max(mse, 1e-10))
                for mse in mse_per_sample.cpu().tolist()
            )
    eval_time_interval = time.perf_counter() - time_stamp
    mean_psnr = np.mean(psnr_list)
    if local_rank == 0:
        print('eval time: {}'.format(eval_time_interval))
        print('mean psnr: {}'.format(mean_psnr))
    psnrlist.append(mean_psnr)
以下是部分训练日志。可以看到 i 和 args.step_per_epoch 的值不对,总是 0/1。这是什么原因?应该如何修改?此外,日志中 data_time_interval 实际上大于 train_time_interval。如何减少 data_time_interval 以提高训练效率?
eval time: 8.030341386795044
mean psnr: 24.4023611466035
training...
epoch:0 0/1 time:0.30+2.96 loss_LPIPS:3.4511e-01
epoch:1 0/1 time:0.93+0.41 loss_LPIPS:3.3067e-01
epoch:2 0/1 time:1.29+0.36 loss_LPIPS:3.3386e-01
epoch:3 0/1 time:4.61+0.37 loss_LPIPS:3.2475e-01
epoch:4 0/1 time:5.26+0.36 loss_LPIPS:3.1935e-01
eval time: 0.9092228412628174
mean psnr: 24.403575688421018
epoch:5 0/1 time:3.94+0.36 loss_LPIPS:3.5425e-01
epoch:6 0/1 time:4.75+0.36 loss_LPIPS:3.5130e-01
epoch:7 0/1 time:3.60+0.36 loss_LPIPS:3.2492e-01
epoch:8 0/1 time:1.40+0.37 loss_LPIPS:3.4967e-01
epoch:9 0/1 time:1.12+0.37 loss_LPIPS:3.4065e-01
...
...
epoch:90 0/1 time:7.11+0.38 loss_LPIPS:3.1828e-01
epoch:91 0/1 time:2.66+0.37 loss_LPIPS:2.7712e-01
epoch:92 0/1 time:6.57+0.36 loss_LPIPS:3.0946e-01
epoch:93 0/1 time:5.59+0.36 loss_LPIPS:2.5663e-01
epoch:94 0/1 time:0.76+0.36 loss_LPIPS:2.8386e-01
eval time: 0.8664124011993408
mean psnr: 24.744649141015156
epoch:95 0/1 time:1.89+0.35 loss_LPIPS:2.8509e-01
epoch:96 0/1 time:2.24+0.36 loss_LPIPS:3.0353e-01
epoch:97 0/1 time:1.49+0.37 loss_LPIPS:3.0354e-01
epoch:98 0/1 time:1.34+0.36 loss_LPIPS:2.9313e-01
epoch:99 0/1 time:1.27+0.36 loss_LPIPS:2.9234e-01
eval time: 0.8701093196868896
mean psnr: 24.784283493153136
解决方案
推荐阅读
- c - Scanf 跳过函数
- facebook - Whitelabel 上的 Facebook 像素
- android - firebase 实时数据库是否支持批量写入/事务?
- c++ - 如何将 utf8 转换为 std::string?
- html - 尝试获取 Bulma 像素细节:底部边框与文本宽度相同的选项卡
- javascript - Navbar Position remain fixed when scroll back to top
- python - 如何在 GET 请求中以正确的格式显示 JSON 数据
- ckan - CKAN API 密钥轮换
- html - 我想使用 css 在绝对定位的图像后面放置一个固定背景
- java - 处理上的 Java 序列化