python - MNIST Pytorch 中的验证错误意外增加
问题描述
我对整个领域有点陌生,因此决定研究 MNIST 数据集。我几乎从https://github.com/pytorch/examples/blob/master/mnist/main.py改编了整个代码,只有一个重大变化:数据加载。我不想在 Torchvision 中使用预加载的数据集。所以我在 CSV 中使用了 MNIST。
我通过从 Dataset 继承并创建一个新的数据加载器从 CSV 文件加载数据。以下是相关代码:
mean = 33.318421449829934
sd = 78.56749081851163
# mean = 0.1307
# sd = 0.3081
import numpy as np
from torch.utils.data import Dataset, DataLoader
class dataset(Dataset):
def __init__(self, csv, transform=None):
data = pd.read_csv(csv, header=None)
self.X = np.array(data.iloc[:, 1:]).reshape(-1, 28, 28, 1).astype('float32')
self.Y = np.array(data.iloc[:, 0])
del data
self.transform = transform
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
item = self.X[idx]
label = self.Y[idx]
if self.transform:
item = self.transform(item)
return (item, label)
import torchvision.transforms as transforms
trainData = dataset('mnist_train.csv', transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mean,), (sd,))
]))
testData = dataset('mnist_test.csv', transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mean,), (sd,))
]))
train_loader = DataLoader(dataset=trainData,
batch_size=10,
shuffle=True,
)
test_loader = DataLoader(dataset=testData,
batch_size=10,
shuffle=True,
)
然而,这段代码给了我你在图片中看到的绝对奇怪的训练错误图,以及 11% 的最终验证错误,因为它将所有内容分类为“7”。
我设法将问题追溯到我如何规范化数据,以及如果我使用示例代码中给出的值(0.1307 和 0.3081)进行 transforms.Normalize,以及将数据读取为“uint8”类型,它工作得很好。请注意,在这两种情况下提供的数据差异非常小。对 0 到 1 的值进行 0.1307 和 0.3081 归一化与对 0 到 255 的值进行 33.31 和 78.56 归一化的效果相同。这些值甚至几乎相同(黑色像素对应于第一种情况下的 -0.4241 和 -0.4242在第二)。
如果您想查看清楚地看到此问题的 IPython Notebook,请查看https://colab.research.google.com/drive/1W1qx7IADpnn5e5w97IcxVvmZAaMK9vL3
我无法理解在这两种略有不同的数据加载方式中,是什么导致了如此巨大的行为差异。任何帮助将不胜感激。
解决方案
长话短说:您需要更改item = self.X[idx]
为item = self.X[idx].copy()
.
长话短说:T.ToTensor()
runs torch.from_numpy
,它返回一个张量,它为你的 numpy array 的内存加上别名dataset.X
。并且T.Normalize()
在原地工作,因此每次抽取样本时都会mean
减去并除以std
,从而导致数据集退化。
编辑:关于它为什么在原始 MNIST 加载器中工作,兔子洞更深。关键MNIST
在于图像被转换为 PIL.Image实例。该操作声称仅在缓冲区不连续的情况下进行复制(在我们的情况下),但在引擎盖下它检查它是否改为跨步(它是),从而复制它。所以幸运的是,默认的 torchvision 管道涉及一个副本,因此就地操作T.Normalize()
不会破坏self.data
我们MNIST
实例的内存。
推荐阅读
- azure-devops - 在自定义 Azure DevOps 任务/扩展中发送对象数组
- python - Python Socket 没有接收到发送给它的消息
- python - 如何更快地迭代二维的 Python numpy.ndarray
- python-3.x - numba-TypingError:在 nopython 模式管道中失败(步骤:nopython 前端)非精确类型 pyobject
- module - Rust 中的跨模块函数调用
- c# - .NET Core 洋葱架构中的依赖注入
- docker - 如何通过远程 docker 的代码设置 repoLayoutRef
- javascript - 互动板
- php - 使用 PHP 删除除 alt 代码之外的特殊字符
- javascript - 如何区分Dropdown item事件?