首页 > 解决方案 > PyTorch DataLoader 随机播放

问题描述

我做了一个实验,并没有得到我期望的结果。

对于第一部分,我正在使用

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=False, num_workers=0)

在训练我的模型之前,我保存trainloader.dataset.targets到变量atrainloader.dataset.data变量中。b然后,我使用trainloader.
训练完成后,我保存trainloader.dataset.targets到变量c中,然后保存到trainloader.dataset.data变量中d。最后,我检查了a == candb == d他们都给了True,这是意料之中的,因为 shuffle 的参数DataLoaderFalse

对于第二部分,我正在使用

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=True, num_workers=0)

在训练我的模型之前,我保存trainloader.dataset.targets到变量etrainloader.dataset.data变量中。f然后,我使用trainloader. 训练完成后,我保存trainloader.dataset.targets到变量g中,然后保存到trainloader.dataset.data变量中h。我期待e == gf == h成为两者,Falseshuffle=True他们True再次给予。我从DataLoader类的定义中遗漏了什么?

标签: pythonneural-networkpytorchshuffletraining-data

解决方案


我相信直接存储在 trainloader.dataset.data 或 .target 中的数据不会被打乱,数据只有在 DataLoader 被称为生成器或迭代器时才会打乱

您可以通过执行 next(iter(trainloader)) 几次而不进行改组和改组来检查它,它们应该会给出不同的结果

import torch
import torchvision

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        ])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = False,
                                         num_workers = 10)
target = dataLoader.dataset.targets


MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)

dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = True,
                                         num_workers = 10)

target_shuffled = dataLoader_shuffled.dataset.targets

print(target == target_shuffled)

_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))

print(target == target_shuffled)

这将给出:

tensor([True, True, True,  ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False, False, False, False, False])

然而,存储在 data 和 target 中的数据和标签是一个固定列表,由于您尝试直接访问它,它们不会被打乱。


推荐阅读