首页 > 解决方案 > 分布式时如何划分数据集

问题描述

现在我想将数据集分为两部分:训练集和验证集。我知道在单个 GPU 上我可以使用采样器做到这一点:

indices = list(range(len(train_data)))
train_loader = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=2)

但是当我想使用 以并行方式训练它时torch.distributed,我必须使用另一个采样器,即,sampler = torch.utils.data.distributed.DistributedSampler(train_data)

那么我应该如何使用这两个采样器,以便我可以划分数据集并同时分发它?

非常感谢您的帮助!

标签: pythonpytorchdistributed

解决方案


torch.utils.data.Dataset您可以在创建之前拆分torch.utils.data.DataLoader

只需像这样使用torch.utils.data.random_split

train, validation =
    torch.utils.data.random_split(
        dataset, 
        (len(dataset)-val_length, val_length)
    )

这将为您提供两个单独的数据集,可以根据需要与数据加载器一起使用。


推荐阅读