首页 > 解决方案 > 如何通过小数据集进行采样以进行比数据大小更多的迭代?

问题描述

我有一个小数据集和一个大数据集,它们表示两个单独的类。我正在训练的网络是风格迁移,所以我需要每个班级的一张图像才能继续训练。但是,一旦较小的数据集用完,训练就会停止。如何保持从小数据集中随机抽样超出其大小?

我试过RandomSampler()了,但没有用。这是我的小数据集代码:

sampler = RandomSampler(self)
dataloader = DataLoader(self, batch_size=26, shuffle=False, sampler=sampler)
while True:
    for data in dataloader:
        yield data

我也尝试过iterator.cycle,但这也没有帮助。

loader = iter(cycle(self.dataset.gen(attribute_id, True)))
A, y_A = next(loader)
B, y_B = next(self.dataset.gen(attribute_id, False))

标签: iteratordeep-learningdatasetpytorch

解决方案


你的想法RandomSampler并不遥远。有一个采样器叫做SubsetRandomSampler. 虽然一个子集通常小于整个集合,但情况并非如此。

假设您的较小数据集有A条目,而您的第二个数据集有B. 您可以定义您的索引:

indices = np.random.randint(0, A, B)   
sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)

这会在对较小数据集有效的范围内生成B索引。

测试:

loader = torch.utils.data.DataLoader(set_A, batch_size=1, sampler=sampler)
print(len(loader)) # B

推荐阅读