PyTorch - Custom DataLoader that runs forever

Problem description

import os
import torch

class TripletImageLoader(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [0] * 10000000

    def __getitem__(self, index):
        pid = os.getpid() % WORKER_SIZE
        # My code here only uses pid, it doesn't use index

        return torch.tensor(batch.data), torch.tensor(batch.label)

    def __len__(self):
        return len(self.data)

I need my data loader to run forever. Right now it always terminates once it reaches 10000000, or whatever maximum size I give it. How do I make it run forever? I don't care about 'index'; I'm not using it. I'm only using this class for its worker capability.

Tags: pytorch

Solution


Since you need to iterate over the same batches for multiple training passes (epochs), the following code skeleton should work for you.

import torch

def train(args, data_loader):
    for idx, ex in enumerate(data_loader):
        # iterate over the mini-batches
        # add your code here
        pass

def validate(args, data_loader):
    with torch.no_grad():
        for idx, ex in enumerate(data_loader):
            # iterate over the mini-batches
            # add your code here
            pass

# args = an object holding the required parameters (num_epochs, batch_size, ...)
# start_epoch = 0 for a fresh run, or the last saved epoch when resuming
for epoch in range(start_epoch, args.num_epochs):
    # train_loader = data loader for the training data
    train(args, train_loader)
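If you want training to literally never stop, rather than run for a fixed number of epochs, a minimal sketch (reusing train, args and train_loader from above) is to swap the epoch range for an endless loop; each pass simply re-iterates over the whole DataLoader.

# Minimal sketch, not part of the original answer: reuses train(), args and
# train_loader defined above. Each pass of the loop is one full epoch over the
# DataLoader, so training only stops when you interrupt it yourself.
epoch = 0
while True:
    train(args, train_loader)
    epoch += 1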

You can set up the data loader as follows.

import torch
from torch.utils.data import Dataset

class ReaderDataset(Dataset):
    def __init__(self, examples):
        # examples = a list of examples
        self.examples = examples

    def __len__(self):
        # return the total dataset size
        return len(self.examples)

    def __getitem__(self, index):
        # return a single item; add your code here
        return self.examples[index]

train_dataset = ReaderDataset(train_examples)
train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.data_workers,
    collate_fn=batchify,
    pin_memory=args.cuda,
    drop_last=args.parallel,
)
# batchify is a custom function to prepare the mini-batches
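The answer does not show batchify itself. As an illustration only, a minimal collate function might look like the sketch below; it assumes each item returned by __getitem__ is a (data, label) pair of tensors, which may not match your actual data.

import torch

def batchify(batch):
    # batch = the list of dataset items collected for one mini-batch
    data = torch.stack([item[0] for item in batch])
    labels = torch.stack([item[1] for item in batch])
    return data, labels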
