首页 > 解决方案 > 如何使用torchvision MNIST数据集中的一个原始批次和一个增强批次进行改组,“对齐”批次样本,但批次大小不同?

问题描述

我想为 torchvision MNIST 数据集实现这种情况,加载数据DataLoader

batch A (unaugmented images): 5, 0, 4, ...
batch B (augmented images): 5*, 5+, 5-, 0*, 0+, 0-, 4*, 4+, 4-, ...

...其中对于 A 的每个图像,批次 B 中有 3 个增强。len(B) = 3*len(A) 相应地。这些批次应在一次迭代中使用,以将批次 A 的原始图像与批次 B 中增强的图像进行比较,以构建损失。

class MyMNIST(Dataset):

def __init__(self, mnist_dir, train, augmented, transform=None, repeat=1):

    self.mnist_dir = mnist_dir
    self.train = train
    self.augmented = augmented
    self.repeat = repeat
    self.transform = transform
    self.dataset = None

    if augmented and train:
        self.dataset = datasets.MNIST(self.mnist_dir, train=train, download=True, transform=transform)
        self.dataset.data = torch.repeat_interleave(self.dataset.data, repeats=self.repeat, dim=0)
        self.dataset.targets = torch.repeat_interleave(self.dataset.targets, repeats=self.repeat, dim=0)
    elif augmented and not train:
        raise Exception("Test set should not be augmented.")
    else:
        self.dataset = datasets.MNIST(MNIST_DIR, train=train, download=True, transform=transform)

使用这个类,我想初始化两个不同的数据加载器:

orig_train = MyMNIST(MNIST_DIR, train=True, augmented=False, transform=orig_transforms)
orig_train_loader = torch.utils.data.DataLoader(orig_train.dataset, batch_size=100, shuffle=True)

aug_train = MyMNIST(MNIST_DIR, train=True, augmented=True, transform=aug_transforms, repeat=3)
aug_train_loader = torch.utils.data.DataLoader(aug_train.dataset, batch_size=300, shuffle=True)

我现在的问题是,我还需要在 A 和 B 之间的顺序保持相关的情况下对每次迭代进行洗牌。上面的代码是不可能的,因为两者都会DataLoader产生不同的订单。因此,我尝试使用单个DataLoader并手动复制重复的批次:

for batch_no, (images, labels) in enumerate(orig_train_loader):
    repeat_images = torch.repeat_interleave(images, 3, dim=0)

这样,我得到了批处理 B ( repeat_images) 的顺序,但现在我错过了需要在批处理/迭代中应用的转换。这似乎不是 Pytorch 的范式,至少我没有找到办法做到这一点。

如果有人可以帮助我,我会很高兴 - 我对 Pytorch(以及 stackoverflow)还很陌生,所以也欢迎批评我的整个方法、可能出现的性能问题等。

非常感谢!

标签: pythondatasetpytorchtorchvisiondataloader

解决方案


推荐阅读