首页 > 解决方案 > 如何使用 PyTorch 从本地目录导入 MNIST 数据集

问题描述

我正在编写MNIST database of handwritten digitsPyTorch 中一个众所周知的问题的代码。我下载了训练和测试数据集(从主网站),包括标记的数据集。数据集格式为t10k-images-idx3-ubyte.gzextract 和之后t10k-images-idx3-ubyte。我的数据集文件夹看起来像

MINST
 Data
  train-images-idx3-ubyte.gz
  train-labels-idx1-ubyte.gz
  t10k-images-idx3-ubyte.gz
  t10k-labels-idx1-ubyte.gz

现在,我编写了一个代码来加载如下数据

def load_dataset():
    data_path = "/home/MNIST/Data/"
    xy_trainPT = torchvision.datasets.ImageFolder(
        root=data_path, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        xy_trainPT, batch_size=64, num_workers=0, shuffle=True
    )
    return train_loader

我的代码正在显示Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp

我该如何解决这个问题,并且我还想检查我的图像是否已从数据集中加载(只有一个图形包含前 5 个图像)?

标签: python-3.xmachine-learningdeep-learningpytorch

解决方案


通过 Python 从 .idx3-ubyte 文件或 GZIP 中读取此提取图像

更新

您可以使用此格式导入数据

xy_trainPT = torchvision.datasets.MNIST(
    root="~/Handwritten_Deep_L/",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)

现在,download=True首先您的代码将检查根目录(您给定的路径)是否包含任何数据集。

如果no然后数据集将从网络下载。

如果yes此路径已包含数据集,则您的代码将使用现有数据集运行,并且不会从 Internet 下载。

可以查看,先给个路径without any dataset(数据会从网上下载),再给个路径,which already contains dataset数据不会下载。


推荐阅读