首页 > 解决方案 > Pytorch:在 dataloader.dataset 上使用 torch.utils.random_split() 后,数据中缺少批量大小

问题描述

我使用 random_split() 将我的数据划分为训练和测试,我观察到如果在创建数据加载器后进行随机拆分,则在从数据加载器获取一批数据时会丢失批量大小。

import torch
from torchvision import transforms, datasets
from torch.utils.data import random_split

# Normalize the data
transform_image = transforms.Compose([
  transforms.Resize((240, 320)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

data = '/data/imgs/train'

def load_dataset():
  data_path = data
  main_dataset = datasets.ImageFolder(
    root = data_path,
    transform = transform_image
  )

  loader = torch.utils.data.DataLoader(
    dataset = main_dataset,
    batch_size= 64,
    num_workers = 0,
    shuffle= True
  )

  # Dataset has 22424 data points
  trainloader, testloader = random_split(loader.dataset, [21000, 1424])

  return trainloader, testloader

trainloader, testloader = load_dataset()

现在从训练和测试加载器中获取一批图像:

images, labels = next(iter(trainloader))
images.shape
# %%
len(trainloader)

# %%
images_test, labels_test = next(iter(testloader))
images_test.shape

# %%
len(testloader)

我得到的输出没有训练或测试批次的批量大小。输出调光应该是 [batch x channel x H x W] 但我得到 [channel x H x W]。

输出:

在此处输入图像描述

但是,如果我从数据集创建拆分,然后使用拆分创建两个数据加载器,我会在输出中获得批量大小。

def load_dataset():
    data_path = data
    main_dataset = datasets.ImageFolder(
      root = data_path,
      transform = transform_image
    )
    # Dataset has 22424 data points
    train_data, test_data = random_split(main_dataset, [21000, 1424])

    trainloader = torch.utils.data.DataLoader(
      dataset = train_data,
      batch_size= 64,
      num_workers = 0,
      shuffle= True
    )

    testloader = torch.utils.data.DataLoader(
      dataset = test_data,
      batch_size= 64,
      num_workers= 0,
      shuffle= True
    )

    return trainloader, testloader

trainloader, testloader = load_dataset()

在运行相同的 4 个命令以获得单个训练和测试批次时:

在此处输入图像描述

第一种方法是错误的吗?虽然长度显示数据已经被分割。那么为什么我看不到批量大小呢?

标签: pythonpython-3.xdeep-learningpytorch

解决方案


第一种方法是错误的。

只有DataLoader实例返回成批的项目。类似的Dataset情况没有。

当你打电话给make_split你时loader.dataset,它只是对main_dataset(不是 a DataLoader)的引用。结果是trainloaderand testloaderare Datasets not DataLoaders。事实上loaderDataLoader当您从load_dataset.

第二个版本是您应该做的以获得两个单独DataLoader的 s。


推荐阅读