首页 > 解决方案 > PyTorch - 丢弃数据加载器批处理

问题描述

我有一个Dataset从大文件加载数据的自定义。有时,加载的数据是空的,我不想将它们用于训练。

Dataset我有:

   def __getitem__(self, i):
       (x, y) = self.getData(i) #getData loads data and handles problems     
       return (x, y)

在错误数据返回的情况下(None, None)x并且y两者都是None)。但是,它后来失败了DataLoader,我无法完全跳过这批。我将批量大小设置为1.

trainLoader = DataLoader(trainDataset, batch_size=1, shuffle=False)
for x_batch, y_batch in trainLoader:
    #process and train

标签: pythonpytorchdataloader

解决方案


您可以实现自定义IterableDataset并定义 a __next____iter__这将跳过您的getData函数引发错误的任何实例:

这是使用虚拟数据的可能实现:

class DS(IterableDataset):
    def __init__(self):
        self.data = torch.randint(0,3,(20,))
        self._i = -1

    def getData(self, index):
        x = self.data[index]
        if x == 0:
            raise ValueError
        return x

    def __iter__(self):
        return self

    def __next__(self):
        self._i += 1
        if self._i == len(self.data):  # out of instances
            self._i = -1               # reset the iterable
            raise StopIteration        # stop the iteration
        try:
            return self.getData(self._i)
        except ValueError:
            return next(self)

你会像这样使用它:

>>> trainLoader = DataLoader(DS(), batch_size=1, shuffle=False)
>>> for x in trainLoader:
...    print(x)
tensor([1])
tensor([2])
tensor([2])
...
tensor([1])
tensor([1])

这里所有0实例都在可迭代数据集中被跳过。

您可以调整这个简单的示例以满足您的需求。


推荐阅读