deep-learning - 在每个 epoch 之后训练 MNIST 数据集时如何输出准确度和损失
问题描述
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from tensorflow.examples.tutorials.mnist import input_data
import torch.optim as optim
import tensorflow.python.util.deprecation as deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
import matplotlib.pyplot as plt
%matplotlib inline
from plot import plot_loss_and_acc
mnist = input_data.read_data_sets("MNIST_data", one_hot=False)
batch_size = 250
epoch_num = 10
lr = 0.0001
disp_freq = 20
def next_batch(train=True):
# Reads the next batch of MNIST images and labels and returns them
if train:
batch_img, batch_label = mnist.train.next_batch(batch_size)
else:
batch_img, batch_label = mnist.test.next_batch(batch_size)
batch_label = torch.from_numpy(batch_label).long() # convert the numpy array into torch tensor
batch_label = Variable(batch_label) # create a torch variable
batch_img = torch.from_numpy(batch_img).float() # convert the numpy array into torch tensor
batch_img = Variable(batch_img) # create a torch variable
return batch_img, batch_label
class MLP(nn.Module):
def __init__(self, n_features, n_classes):
super(MLP, self).__init__()
self.layer1 = nn.Linear(n_features, 128)
self.layer2 = nn.Linear(128, 128)
self.layer3 = nn.Linear(128, n_classes)
def forward(self, x, training=True):
# a neural network with 2 hidden layers
# x -> FC -> relu -> dropout -> FC -> relu -> dropout -> FC -> output
x = F.relu(self.layer1(x))
x = F.dropout(x, 0.5, training=training)
x = F.relu(self.layer2(x))
x = F.dropout(x, 0.5, training=training)
x = self.layer3(x)
return x
def predict(self, x):
# a function to predict the labels of a batch of inputs
x = F.softmax(self.forward(x, training=False))
return x
def accuracy(self, x, y):
# a function to calculate the accuracy of label prediction for a batch of inputs
# x: a batch of inputs
# y: the true labels associated with x
prediction = self.predict(x)
maxs, indices = torch.max(prediction, 1)
acc = 100 * torch.sum(torch.eq(indices.float(), y.float()).float())/y.size()[0]
print(acc.data)
return acc.data
# define the neural network (multilayer perceptron)
net = MLP(784, 10)
# calculate the number of batches per epoch
batch_per_ep = mnist.train.num_examples // batch_size
# define the loss (criterion) and create an optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
print(' ')
print("__________Training__________________")
xArray = []
yLoss = []
yAcc = []
for ep in range(epoch_num): # epochs loop
for batch_n in range(batch_per_ep): # batches loop
features, labels = next_batch()
# Reset gradients
optimizer.zero_grad()
# Forward pass
output = net(features)
loss = criterion(output, labels)
# Backward pass and updates
loss.backward() # calculate the gradients (backpropagation)
optimizer.step() # update the weights
if batch_n % disp_freq == 0:
print('epoch: {} - batch: {}/{} '.format(ep, batch_n, batch_per_ep))
xArray.append(ep)
yLoss.append(loss.data)
#yAcc.append(acc.data)
print('loss: ', loss.data)
print('__________________________________')
# test the accuracy on a batch of test data
features, labels = next_batch(train=False)
print("Result")
print('Test accuracy: ', net.accuracy(features, labels))
print('loss: ', loss.data)
accuracy = net.accuracy(features, labels)
#Loss Plot
# plotting the points
plt.plot(xArray, yLoss)
# naming the x axis
plt.xlabel('epoch')
# naming the y axis
plt.ylabel('loss')
# giving a title to my graph
plt.title('Loss Plot')
# function to show the plot
plt.show()
#Accuracy Plot
# plotting the points
plt.plot(xArray, yAcc)
# naming the x axis
plt.xlabel('epoch')
# naming the y axis
plt.ylabel(' accuracy')
# giving a title to my graph
plt.title('Accuracy Plot ')
# function to show the plot
plt.show()
我想显示我的训练数据集的准确性。我已经设法显示和绘制损失,但我没有设法做到准确。我知道我缺少 1 或 2 行代码,但我不知道该怎么做。
我的意思是,如果我可以像损失一样显示每个时期的准确性,我可以自己进行绘图。
解决方案
嗨,将此代码替换print('epoch: {} - batch: {}/{} '.format(ep, batch_n, batch_per_ep))
为
print('epoch: {} - batch: {}/{} - accuracy: {}'.format(ep, batch_n, batch_per_ep, net.accuracy(features,labels)))
希望这可以帮助。
推荐阅读
- spring - 在运行时禁用更改 Spring Boot 日志记录级别
- java - 多个文本输入和按钮 Java
- javascript - 做一个 API 顶点项目,无法从文本输入中检索值
- java - 无法通过 Java 中的 Selenium 定位父 iframe 中的子 iframe 元素
- sql - 将数据插入oracle sql live
- ruby - 红宝石块哈希输出的差异
- javascript - React 中的 useState() 是什么?
- hadoop - 插入 hive orc 分区表时出现运行时异常
- c# - 如何安装 C# 客户端库,尤其是 Google Cloud Text to Speech?
- amazon-web-services - ValueError:当在 Sagemaker 中使用 Gunicorn、Flask 和 Keras 托管模型时,张量不是该图的元素