首页 > 解决方案 > MNIST Pytorch 中的验证错误意外增加

问题描述

我对整个领域有点陌生,因此决定研究 MNIST 数据集。我几乎从https://github.com/pytorch/examples/blob/master/mnist/main.py改编了整个代码,只有一个重大变化:数据加载。我不想在 Torchvision 中使用预加载的数据集。所以我在 CSV 中使用了 MNIST

我通过从 Dataset 继承并创建一个新的数据加载器从 CSV 文件加载数据。以下是相关代码:

mean = 33.318421449829934
sd = 78.56749081851163
# mean = 0.1307
# sd = 0.3081
import numpy as np
from torch.utils.data import Dataset, DataLoader

class dataset(Dataset):
    def __init__(self, csv, transform=None):
        data = pd.read_csv(csv, header=None)
        self.X = np.array(data.iloc[:, 1:]).reshape(-1, 28, 28, 1).astype('float32')
        self.Y = np.array(data.iloc[:, 0])

        del data
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        item = self.X[idx]
        label = self.Y[idx]

        if self.transform:
            item = self.transform(item)

        return (item, label)

import torchvision.transforms as transforms
trainData = dataset('mnist_train.csv', transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (sd,))
]))
testData = dataset('mnist_test.csv', transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (sd,))
]))

train_loader = DataLoader(dataset=trainData,
                         batch_size=10, 
                         shuffle=True,
                         )
test_loader = DataLoader(dataset=testData, 
                        batch_size=10, 
                        shuffle=True,
                        )

然而,这段代码给了我你在图片中看到的绝对奇怪的训练错误图,以及 11% 的最终验证错误,因为它将所有内容分类为“7”。 验证错误图

我设法将问题追溯到我如何规范化数据,以及如果我使用示例代码中给出的值(0.1307 和 0.3081)进行 transforms.Normalize,以及将数据读取为“uint8”类型,它工作得很好。请注意,在这两种情况下提供的数据差异非常小。对 0 到 1 的值进行 0.1307 和 0.3081 归一化与对 0 到 255 的值进行 33.31 和 78.56 归一化的效果相同。这些值甚至几乎相同(黑色像素对应于第一种情况下的 -0.4241 和 -0.4242在第二)。

如果您想查看清楚地看到此问题的 IPython Notebook,请查看https://colab.research.google.com/drive/1W1qx7IADpnn5e5w97IcxVvmZAaMK9vL3

我无法理解在这两种略有不同的数据加载方式中,是什么导致了如此巨大的行为差异。任何帮助将不胜感激。

标签: pythondeep-learningpytorchmnist

解决方案


长话短说:您需要更改item = self.X[idx]item = self.X[idx].copy().

长话短说:T.ToTensor()runs torch.from_numpy,它返回一个张量,它为你的 numpy array 的内存加上别名dataset.X。并且T.Normalize() 在原地工作,因此每次抽取样本时都会mean减去并除以std,从而导致数据集退化。

编辑:关于它为什么在原始 MNIST 加载器中工作,兔子洞更深。关键MNIST在于图像被转换​​ PIL.Image实例。该操作声称仅在缓冲区不连续的情况下进行复制(在我们的情况下),但在引擎盖下它检查它是否改为跨步(它是),从而复制它。所以幸运的是,默认的 torchvision 管道涉及一个副本,因此就地操作T.Normalize()不会破坏self.data我们MNIST实例的内存。


推荐阅读