首页 > 解决方案 > 如何在 PyTourch 中创建平衡循环迭代器?

问题描述

假设我有 2 节课。对于一个,我只有 17 个样本,另一个是 83 个。我希望每个 epoch 的每个类始终拥有相同数量的数据(在这种情况下意味着 17 x 17)。另外,我想在班级中滑动采样一个窗口,每个时期都有更多数据(前 17 个,下一个 17,...)。

目前我有一个这样的循环采样迭代器:

class CyclicIterator:
    def __init__(self, loader, sampler):
        self.loader = loader
        self.sampler = sampler
        self.epoch = 0
        self._next_epoch()

    def _next_epoch(self):
        self.iterator = iter(self.loader)
        self.epoch += 1

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self._next_epoch()
            return next(self.iterator)

我想知道如何强制每个类别的所有样本在每个时期都具有相同的数量?

标签: pythonpytorch

解决方案


对于平衡批次,这意味着每个批次中每个类别的样本数量相等(或接近相等),有一些方法:

-过采样(使较小的类过采样,直到达到最大样本数)。在这种方法中,您可以使用以下代码:

https://github.com/galatolofederico/pytorch-balanced-batch

- 欠采样(提供基于最小类别编号的所有类别的样本数量)。根据我的经验,下面的函数确实像使用 PyTorch 库:

torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

其中 weights 是每个样本的概率,它取决于您拥有的每个类别的样本数量,例如,如果您的数据很简单,因为 data = [0, 1, 0, 0, 1], class '0' count is 3,并且“1”类计数为 2,因此权重向量为 [1/3, 1/2, 1/3, 1/3, 1/2]。有了它,您可以调用 WeightedRamdomSampler,它会为您服务。您需要在 Dataloader 中调用它。设置它的代码是:

sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
train_dataloader = DataLoader(dataset_train, batch_size=mini_batch,
                              sampler=sampler, shuffle=False,
                              num_workers=1)

推荐阅读