首页 > 解决方案 > PyTorch Sampler:如果指定,下一次迭代采样相同的子集还是另一个子集?

问题描述

如果我有如下代码:

dataset = Dataset(...)
sampler = RandomSampler(...)
dataloader = DataLoader(..., sampler=sampler)

然后每当我打电话时:

for data, label in dataloader:
   ...

data, label与上次调用相比,返回的元组是相同的子集还是不同的子集?

标签: pythonpytorch

解决方案


与上次调用相比,它是不同的子集。我在这里为您的问题修改示例:

data = torch.rand(10,1)
dataset = torch.utils.data.TensorDataset(torch.arange(len(data)),data)
index,_ = dataset[:]

sampler = torch.utils.data.RandomSampler(index)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=3)

for i in range(2):
  for data, label in loader:
    print(data, label)
  print("------------")

推荐阅读