首页 > 解决方案 > mnist数据的pytorch分类器不起作用

问题描述

我是 pytorch 的新手,我尝试用 mnist 数据训练一个简单的分类器。但是,我的分类器的准确率在 10% 左右,我尝试了几种方法来调整网络,但都失败了,分类器的输出标签总是一样的,全 0,或全 7,或全 6。请告诉我有什么问题代码。(我知道我应该使用DataLoader,我稍后会看,现在我只是想让分类器的准确性看起来不错)

# coding=utf-8
# 数据为data中的handwritten_digit

import struct
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

data_folder = '../data/handwritten_digit/'
dt = torch.get_default_dtype()
train_label_file = 'train-labels-idx1-ubyte'
train_img_file = 'train-images-idx3-ubyte'
test_img_file = 't10k-images-idx3-ubyte'
test_label_file = 't10k-labels-idx1-ubyte'
model_path = './handwritten_digit_recognition_net3.pth'

def timer(func):
    def cal_time(*args, **kw):
        start_time = time.time()
        out = func(*args, **kw)
        end_time = time.time()
        print('函数 ', func.__name__, ' 运行耗时', end_time-start_time, '秒', sep = '')
        return out
    return cal_time

def read_imgs(file):
    with open(data_folder+file, 'rb') as frb:
        # 先读取meta
        magic_num, img_num, row_num, col_num = struct.unpack('>IIII', frb.read(16))
        # print(magic_num, img_num, row_num, col_num)
        # img = np.fromfile(frb, dtype = np.uint8, count = row_num*col_num).reshape(row_num, col_num)
        # print(img, img.shape, 'img')
        imgs = np.fromfile(frb, dtype = np.uint8).reshape(img_num, row_num, col_num)
        # imgs = np.fromfile(frb, dtype = np.uint8, count = row_num*col_num*img_num).reshape(img_num, row_num, col_num)
    return torch.from_numpy(imgs).type(dt).unsqueeze(1).unsqueeze(1)

def read_labels(file):
    with open(data_folder+file, 'rb') as frb:
        # 先读取meta
        magic_num, label_num = struct.unpack('>II', frb.read(8))
        # print(magic_num, label_num)
        labels = np.fromfile(frb, dtype = np.uint8)
    return torch.from_numpy(labels).type(dt)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.conv2 = nn.Conv2d(12, 12, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(12*16, 30)
        self.linear2 = nn.Linear(30, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 12*16)
        # print(x.size(), 'x.size()')
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x



@timer
def train_and_save_net():
    train_imgs = read_imgs(train_img_file)
    train_labels = read_labels(train_label_file)
    test_imgs = read_imgs(test_img_file)
    test_labels = read_labels(test_label_file)

    # label = torch.zeros(1, 10)

    # label[0][int(train_labels[0])] = 1
    # print(label)
    # print(train_labels[0])
    # return

    net = Net()
    # criterion = nn.MSELoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum = 0.9)

    print('Start Training')
    sum_loss = 0
    for i, img in enumerate(train_imgs):
        optimizer.zero_grad()
        predicted = net(img)
        # label = torch.zeros(1, 10)
        # label[0][int(train_labels[i])] = 1
        label = torch.tensor([train_labels[i]], dtype = torch.long)
        # print(predicted, predicted.size(), 'predicted')
        # print(label, label.size(), 'label')
        loss = criterion(predicted, label)
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        if i % 2000 == 1999:
            print('已经训练了', i+1, '张图片,', '完成进度:', '%.2f'%((i+1)/len(train_labels)*100), '%', sep = '')
            print('loss为:', sum_loss/2000)
            sum_loss = 0
    print('End Training')

    torch.save(net.state_dict(), model_path)
    print('End Saving Net Parameters')


def load_net():
    net = Net()
    net.load_state_dict(torch.load(model_path))
    return net

@timer
def evaluate():

    train_imgs = read_imgs(train_img_file)
    train_labels = read_labels(train_label_file)
    test_imgs = read_imgs(test_img_file)
    test_labels = read_labels(test_label_file)

    net = load_net()

    # 直观感受
    for i in range(5):
        img = train_imgs[i]
        # plt.imshow(img.squeeze(), cmap = 'gray')
        # plt.show()
        predicted_vector = net(img)
        _, predicted = torch.max(predicted_vector, 1)
        predicted = predicted.item()
        print('预测的分类是:', predicted, ',实际的分类是:', int(train_labels[i].item()), sep = '')

    # 训练集精度
    total = len(train_labels)
    correct = 0
    for i in range(len(train_labels)):
        img = train_imgs[i]
        predicted_vector = net(img)
        _, predicted = torch.max(predicted_vector, 1)
        label = int(train_labels[i].item())
        if predicted == label:
            correct += 1
    print('训练集上的准确率为:', '%.2f'%(correct/total*100), '%', sep = '')


    total = len(test_labels)
    correct = 0
    pre_arr = []
    for i in range(len(test_labels)):
        img = test_imgs[i]
        predicted_vector = net(img)
        _, predicted = torch.max(predicted_vector, 1)
        label = int(test_labels[i].item())
        pre_arr.append(predicted)
        if predicted == label:
            correct += 1
    print('测试集上的准确率为:', '%.2f'%(correct/total*100), '%', sep = '')
    print('模型判断为0的个数/总判断数 为:', pre_arr.count(0), '/', len(pre_arr), sep = '')

@timer
def test():
    predicted_vector = torch.randn(1,10)
    _, predicted = torch.max(predicted_vector, 1)
    print(predicted.item())

if __name__ == '__main__':
    train_and_save_net()
    # test()
    evaluate()

标签: pythonmachine-learningdeep-learningneural-networkpytorch

解决方案


嗯,我好像弄清楚问题出在哪里了,我把学习率从 1e-3 改成了 1e-4,然后准确率就达到了 97% 左右……


推荐阅读