首页 > 解决方案 > GPU 内存概念中的 PyTorch 数据加载器“数据加载”

问题描述

我正在使用 pytorch 数据加载器来加载我的 7.5 GB 数据集。我正在使用 100 的批量大小来加载数据。但是我不确定我是一次性加载整个数据集还是批量加载数据。即使批量大小低至 1,它也会给我CUDA 内存不足错误。我想知道默认(如pytorch文档中所示)dataloader是否使用批量加载或一次性转储所有数据?如果它尝试转储 7.5 GB 的数据,则此错误可能是由于数据集过大造成的。

作为参考,我附上了相同的代码-

class Image(Dataset):
    def __init__(self, setname):
        csv_path = osp.join(ROOT_PATH, setname + '.csv')
        lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]

        data = []
        label = []
        lb = -1

        self.wnids = []

        for l in lines:
            name, wnid = l.split(',')
            path = osp.join(ROOT_PATH, 'images', name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            data.append(path)
            label.append(lb)

        self.data = data
        self.label = label

        self.transform = transforms.Compose([
            transforms.Resize(140),
            transforms.CenterCrop(140),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label

以上是数据加载器部分。在主脚本中,它被称为如下 -

for epoch in range(1, args.max_epoch + 1):
    lr_scheduler.step()

    model.train()

    for i, batch in enumerate(train_loader, 1):
        data, _ = [_.cuda() for _ in batch]
        p = args.shot * args.train_way
        data_shot, data_query = data[:p], data[p:]

        proto = model(data_shot)

        label = torch.arange(args.train_way).repeat(args.query)
        label = label.type(torch.cuda.LongTensor)

        logits = euclidean_metric(model(data_query), proto)
        loss = F.cross_entropy(logits, label)
        tl.add(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

标签: pytorchgpuconv-neural-network

解决方案


推荐阅读