首页 > 解决方案 > 如何在 PyTorch 中为非图像数据创建小批量?

问题描述

我想加载我的训练和测试数据

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, ), (0.5, ))])

trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)


testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

我看到图像数据的实现有没有办法以类似的方式加载非图像数据?

标签: pytorch

解决方案


我们可以torch.utils.data通过以下步骤使用模块,

  1. 通过继承创建一个Dataset类来加载自定义数据torch.utils.data.Dataset

  2. 通过将数据传递给自定义 Dataset 类的实例来创建数据集对象

  3. 用于torch.utils.data.DataLoader加载数据集并获取批次

假设您已从目录中加载数据,在训练和测试 numpy 数组中,您可以从torch.utils.data.Dataset类继承来创建数据集对象

class MyDataset(Dataset):
    def __init__(self, x, y):
        super(MyDataset, self).__init__()
        assert x.shape[0] == y.shape[0] # assuming shape[0] = dataset size
        self.x = x
        self.y = y


    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

然后,创建您的数据集对象

traindata = MyDataset(train_x, train_y)

最后,用于DataLoader创建您的小批量

trainloader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)

推荐阅读