python - mnist数据的pytorch分类器不起作用
问题描述
我是 pytorch 的新手,我尝试用 mnist 数据训练一个简单的分类器。但是,我的分类器的准确率只有 10% 左右。我尝试了几种方法来调整网络,但都失败了:分类器输出的标签总是一样的,要么全是 0,要么全是 7,要么全是 6。请告诉我代码有什么问题。(我知道我应该使用 DataLoader,我稍后会去看;现在我只是想让分类器的准确率看起来正常)
# coding=utf-8
# 数据为data中的handwritten_digit
import struct
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
# Directory holding the raw data files. It is prefixed to each file name by
# plain string concatenation, so the trailing slash is required.
data_folder = '../data/handwritten_digit/'
# torch's default floating-point dtype; loaded images and labels are cast to it.
dt = torch.get_default_dtype()
# File names (relative to data_folder) of the training / test labels and images.
train_label_file = 'train-labels-idx1-ubyte'
train_img_file = 'train-images-idx3-ubyte'
test_img_file = 't10k-images-idx3-ubyte'
test_label_file = 't10k-labels-idx1-ubyte'
# Path the trained network's state_dict is saved to and loaded from.
model_path = './handwritten_digit_recognition_net3.pth'
def timer(func):
    """Decorate ``func`` so that every call prints its wall-clock run time.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func``, prints the function's name together with the elapsed
        seconds, and returns whatever ``func`` returned. If ``func`` raises,
        the exception propagates and nothing is printed.
    """
    # Imported locally so this block does not rely on the file's import list.
    import functools

    # wraps() keeps func.__name__ / __doc__ on the wrapper, so decorated
    # functions still identify themselves correctly (e.g. in tracebacks).
    @functools.wraps(func)
    def cal_time(*args, **kw):
        start_time = time.time()
        out = func(*args, **kw)
        end_time = time.time()
        print('函数 ', func.__name__, ' 运行耗时', end_time-start_time, '秒', sep = '')
        return out

    return cal_time
def read_imgs(file):
    """Load an image file from ``data_folder`` into a float tensor.

    The file starts with a 16-byte header of four big-endian unsigned 32-bit
    integers (magic number, image count, rows, columns); every remaining byte
    is read as one uint8 pixel.

    Args:
        file: File name relative to ``data_folder``.

    Returns:
        A tensor of dtype ``dt`` with shape (count, 1, 1, rows, cols), so that
        indexing or iterating along the first axis yields one
        (1, 1, rows, cols) tensor per image. Pixel values are cast but not
        rescaled.
    """
    header_format = '>IIII'
    with open(data_folder + file, 'rb') as stream:
        header = stream.read(struct.calcsize(header_format))
        _magic, count, height, width = struct.unpack(header_format, header)
        pixels = np.fromfile(stream, dtype=np.uint8)
    images = torch.from_numpy(pixels.reshape(count, height, width)).type(dt)
    # Insert two singleton axes after the image axis:
    # (count, h, w) -> (count, 1, 1, h, w).
    return images[:, None, None]
def read_labels(file):
    """Load a label file from ``data_folder`` into a 1-D tensor.

    The file starts with an 8-byte header of two big-endian unsigned 32-bit
    integers (magic number, label count); every remaining byte is read as one
    uint8 label.

    Args:
        file: File name relative to ``data_folder``.

    Returns:
        A 1-D tensor of dtype ``dt`` (a floating-point dtype), one entry per
        label.
    """
    with open(data_folder + file, 'rb') as stream:
        # The header values are not needed, but it is still unpacked so that
        # a file shorter than 8 bytes raises struct.error.
        struct.unpack('>II', stream.read(8))
        raw = np.fromfile(stream, dtype=np.uint8)
    return torch.from_numpy(raw).type(dt)
class Net(nn.Module):
    """Small CNN: two conv+ReLU+max-pool stages followed by two linear layers.

    For a (batch, 1, 28, 28) input the spatial size goes
    28 -> 24 (5x5 conv) -> 12 (2x2 pool) -> 8 (5x5 conv) -> 4 (2x2 pool),
    leaving 12 channels * 4 * 4 = 12*16 features for the first linear layer.
    The output is one row of 10 raw scores (no softmax) per sample.
    """

    def __init__(self):
        super().__init__()
        channels = 12
        kernel = 5
        flat_features = channels * 16
        self.conv1 = nn.Conv2d(1, channels, kernel)
        self.conv2 = nn.Conv2d(channels, channels, kernel)
        self.pool = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(flat_features, 30)
        self.linear2 = nn.Linear(30, 10)

    def forward(self, x):
        """Return the (rows, 10) score tensor for the image batch ``x``."""
        # Both convolution stages share the same ReLU + pooling tail.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to rows of linear1.in_features (= 12*16) values.
        flat = x.view(-1, self.linear1.in_features)
        hidden = F.relu(self.linear1(flat))
        return self.linear2(hidden)
@timer
def train_and_save_net(lr=1e-4):
    """Train a fresh ``Net`` for one pass over the training set and save it.

    Images are fed one at a time (batch size 1) through SGD with momentum 0.9
    and cross-entropy loss. The mean loss is printed every 2000 images, and
    the final ``state_dict`` is written to ``model_path``.

    Args:
        lr: SGD learning rate. Defaults to 1e-4. ``read_imgs`` hands over raw
            pixel values (uint8 cast to float, not rescaled), and with the
            previous value of 1e-3 this script was reported to end up at
            about 10% accuracy (always the same predicted label), versus
            about 97% with 1e-4.
    """
    train_imgs = read_imgs(train_img_file)
    # CrossEntropyLoss takes integer class indices as targets; read_labels
    # returns floats, so convert once up front instead of on every step.
    train_labels = read_labels(train_label_file).long()
    total = len(train_labels)
    report_every = 2000

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

    print('Start Training')
    sum_loss = 0
    for i, (img, label) in enumerate(zip(train_imgs, train_labels)):
        optimizer.zero_grad()
        predicted = net(img)
        # predicted is (1, 10); the target must therefore have shape (1,).
        loss = criterion(predicted, label.unsqueeze(0))
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        if i % report_every == report_every - 1:
            print('已经训练了', i+1, '张图片,', '完成进度:', '%.2f'%((i+1)/total*100), '%', sep = '')
            print('loss为:', sum_loss/report_every)
            sum_loss = 0
    print('End Training')

    torch.save(net.state_dict(), model_path)
    print('End Saving Net Parameters')
def load_net():
    """Build a ``Net`` and restore its weights from ``model_path``.

    Returns:
        The ``Net`` instance with the saved ``state_dict`` loaded.
    """
    model = Net()
    saved_state = torch.load(model_path)
    model.load_state_dict(saved_state)
    return model
def _predict_digit(net, img):
    """Return, as an int, the class index ``net`` scores highest for ``img``."""
    _, predicted = torch.max(net(img), 1)
    return predicted.item()


def _accuracy_percent(net, imgs, labels):
    """Classify ``imgs`` one by one and compare against ``labels``.

    Returns:
        A tuple ``(accuracy, predictions)``: the accuracy in percent over
        ``len(labels)`` and the list of predicted class indices (ints).
    """
    predictions = []
    correct = 0
    for img, label in zip(imgs, labels.tolist()):
        digit = _predict_digit(net, img)
        predictions.append(digit)
        if digit == int(label):
            correct += 1
    return correct / len(labels) * 100, predictions


@timer
def evaluate():
    """Load the saved network and print its accuracy on both data sets.

    Prints the predicted and actual class for the first five training images,
    then the accuracy on the training set and on the test set, and finally
    how many test images were classified as 0.
    """
    train_imgs = read_imgs(train_img_file)
    train_labels = read_labels(train_label_file)
    test_imgs = read_imgs(test_img_file)
    test_labels = read_labels(test_label_file)
    net = load_net()

    # Pure inference: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        # A quick look at a few individual predictions.
        for i in range(5):
            predicted = _predict_digit(net, train_imgs[i])
            print('预测的分类是:', predicted, ',实际的分类是:', int(train_labels[i].item()), sep = '')

        # Accuracy on the training set.
        train_accuracy, _ = _accuracy_percent(net, train_imgs, train_labels)
        print('训练集上的准确率为:', '%.2f'%train_accuracy, '%', sep = '')

        # Accuracy on the test set. The predictions are kept as plain ints so
        # that list.count(0) below compares numbers rather than tensors.
        test_accuracy, pre_arr = _accuracy_percent(net, test_imgs, test_labels)
        print('测试集上的准确率为:', '%.2f'%test_accuracy, '%', sep = '')
        print('模型判断为0的个数/总判断数 为:', pre_arr.count(0), '/', len(pre_arr), sep = '')
@timer
def test():
    """Scratch check: print the argmax index of a random (1, 10) score row."""
    scores = torch.randn(1, 10)
    best = torch.max(scores, 1)
    # torch.max over a dim returns (values, indices); print the single index.
    print(best[1].item())
if __name__ == '__main__':
    # Train first so that evaluate() loads a freshly saved checkpoint.
    # The scratch helper test() is intentionally not part of the pipeline.
    for step in (train_and_save_net, evaluate):
        step()
解决方案
嗯,我好像弄清楚问题出在哪里了:我把学习率从 1e-3 改成了 1e-4,然后准确率就达到了 97% 左右……(可能的原因是:代码中输入图像直接使用了 0~255 的原始像素值,没有做归一化,因此 1e-3 的学习率对这个网络来说偏大,训练容易发散,输出也就总是同一个标签。)
推荐阅读
- ruby - 如何以无人身份运行 ruby
- java - Spring Boot 应用程序中的多个 Spring Data JPA 模块(非 Spring Boot)依赖项?
- c++ - C++:为什么我可以将指针值成员变量从 const 成员函数传递给采用非 const 指针参数的外部函数?
- python - Python 自动化 - 自动将文件移动到另一个文件夹 - 代码未执行
- sql - 关于SQL Server中同一列的除法问题
- python - aws-xray-sdk-python:手动创建的段未显示在 AWS 服务地图上
- reactjs - 获取 Redux 中所有相关方法的完整调用堆栈
- c# - 通过代码创建和执行的 SSIS 包需要在多次运行后重新启动服务
- php - 从 csv 文件读取数据并使用 phpspreadsheet 写入 excel
- jquery - 获取ajax加载元素的值