python - PyTorch - 丢弃数据加载器批处理
问题描述
我有一个Dataset
从大文件加载数据的自定义。有时,加载的数据是空的,我不想将它们用于训练。
在Dataset
我有:
def __getitem__(self, i):
(x, y) = self.getData(i) #getData loads data and handles problems
return (x, y)
在错误数据返回的情况下(None, None)
(x
并且y
两者都是None
)。但是,它后来失败了DataLoader
,我无法完全跳过这批。我将批量大小设置为1
.
trainLoader = DataLoader(trainDataset, batch_size=1, shuffle=False)
for x_batch, y_batch in trainLoader:
#process and train
解决方案
您可以实现自定义IterableDataset
并定义 a __next__
,__iter__
这将跳过您的getData
函数引发错误的任何实例:
这是使用虚拟数据的可能实现:
class DS(IterableDataset):
def __init__(self):
self.data = torch.randint(0,3,(20,))
self._i = -1
def getData(self, index):
x = self.data[index]
if x == 0:
raise ValueError
return x
def __iter__(self):
return self
def __next__(self):
self._i += 1
if self._i == len(self.data): # out of instances
self._i = -1 # reset the iterable
raise StopIteration # stop the iteration
try:
return self.getData(self._i)
except ValueError:
return next(self)
你会像这样使用它:
>>> trainLoader = DataLoader(DS(), batch_size=1, shuffle=False)
>>> for x in trainLoader:
... print(x)
tensor([1])
tensor([2])
tensor([2])
...
tensor([1])
tensor([1])
这里所有0
实例都在可迭代数据集中被跳过。
您可以调整这个简单的示例以满足您的需求。
推荐阅读
- python - 使用 Kubernetes python 客户端列出命名空间中的所有资源
- java - 如何将响应字符串转换为 json 对象
- ruby-on-rails - 实现文件请求计数器
- python - “S3File”对象没有“强制”属性
- ruby-on-rails - 从 Rails 模型调用 Gem 中的方法
- r - 我应该如何处理合并(完全加入)多个(> 100)CSV 文件与一个公用键但行数不一致?
- perl - 使用 Perl 将文件读入两个散列
- javascript - forloop onclick 显示文本字段
- arrays - Python:命名占位符如何组织字符串数组中的数据?
- python - 包含 json 格式列的 Dask 数据框