首页 > 解决方案 > PyTorch DataLoader:每次调用时都遍历数据的子集而不是整个数据

问题描述

假设我有一个火炬dataloader = DataLoader(...)对象。每当我调用函数时,我都不想遍历整个数据集for data, label in dataloader:,所以目前我使用:

dataloader = DataLoader(...)
iter_dataloader = iter(dataloader)
batch = iter_dataloader.next()  # Set the first batch

def train_batch():
   data, label = batch
   prediction = model(data)
   # Do fancy things here
   try: 
      batch = iter_dataloader.next()  # Load the next batch
   except:
      iter_dataloader = iter(dataloader)  # if the iterator object reaches the end, reset the dataloader
      batch = iter_dataloader.next()  

for _ in range(N):
   train_batch()  # This function is called multiple times

对于 的每次调用train_batch(),我从数据集中获取一个批次,训练模型,然后加载下一个批次。如果没有剩余批次,我将重置 DataLoader 对象。

现在我的问题:

  1. 有没有办法让代码更清晰?也就是说,我不想使用iterandnext方法。每次我调用它时,它都会从中自动采样一批,并在到达末尾时自动重置。我听说过Sampler,但我没有使用它。
  2. 上面的扩展:代替批次,我可以使用K批次或1/K我正在使用的数据集大小吗?
  3. 我想要三种抽样方法:(1)按顺序抽样批次(无随机播放),(2)随机抽样(有和没有替换 - 随机抽样),以及(3)从中抽样,使标签相等。有没有办法做到这一点?

标签: pythonpytorch

解决方案


推荐阅读