首页 > 解决方案 > 如何将 MNIST 图像加载到 Pytorch DataLoader 中?

问题描述

用于数据加载和处理的 pytorch 教程非常针对一个示例,有人可以帮助我了解更通用的简单图像加载功能应该是什么样的吗?

教程: http: //pytorch.org/tutorials/beginner/data_loading_tutorial.html

我的数据:

我在以下文件夹结构中有 MINST 数据集作为 jpg 的。(我知道我可以只使用数据集类,但这纯粹是为了看看如何在没有 csv 或复杂功能的情况下将简单图像加载到 pytorch 中)。

文件夹名称是标签,图像是 28x28 png 的灰度,不需要转换。

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...

标签: pythonpytorch

解决方案


这是我为 pytorch 0.4.1 所做的(应该在 1.3 中仍然有效)

def load_dataset():
    data_path = 'data/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    #train network

推荐阅读