python - 图像分类深度学习模型显示高精度但给出错误的预测
问题描述
我已经使用 CNN 和 CIFAR 10 数据集训练了一个用于图像分类的深度学习模型。它有 10 个类,持续 50 个 epoch。训练完成后显示的准确率为 96.02%,验证准确率为 96.96%。但是,当我将图像传递给它进行检测和测试时,即使是对象很明显的图片,它也会显示错误的答案。例如:它将青蛙检测为汽车或飞机,将鹿检测为飞机或马,将鸟类检测为猫或狗。
我已经尝试用更多的神经元来训练它。然后我想可能是因为我在灰度图像上训练模型并传递彩色图像,所以我在彩色图像上训练它,它显示了相同的结果。
import os
from datetime import time
from random import shuffle
#import TensorBoard as TensorBoard
#import TensorBoard as TensorBoard
import cv2
import numpy as np
import os
from random import shuffle
from tqdm import tqdm
import tensorflow as tf
import matplotlib.pyplot as plt
import tflearn
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
TRAIN_DIR = '../input/dataset/Train'
TEST_DIR = '../input/dataset/Test'
IMG_SIZE = 32
LR = 1e-3 # 1x10^-3 or 0.003
MODEL_NAME = 'image classifier'
def create_label(image_name):
""" Create an one-hot encoded vector from image name """
word_label = image_name.split('_')[1:2]
word_label = word_label[0].split('.')[0:1]
word_label = word_label[0]
if word_label == 'cat':
return np.array([1,0,0,0,0,0,0,0,0,0])
elif word_label == 'dog':
return np.array([0,1,0,0,0,0,0,0,0,0])
elif word_label == 'automobile':
return np.array([0,0,1,0,0,0,0,0,0,0])
elif word_label == 'airplane':
return np.array([0,0,0,1,0,0,0,0,0,0])
elif word_label == 'ship':
return np.array([0,0,0,0,1,0,0,0,0,0])
elif word_label == 'frog':
return np.array([0,0,0,0,0,1,0,0,0,0])
elif word_label == 'truck':
return np.array([0,0,0,0,0,0,1,0,0,0])
elif word_label == 'bird':
return np.array([0,0,0,0,0,0,0,1,0,0])
elif word_label == 'horse':
return np.array([0,0,0,0,0,0,0,0,1,0])
elif word_label == 'deer':
return np.array([0,0,0,0,0,0,0,0,0,1])
def create_train_data():
training_data = []
for img in tqdm(os.listdir(TRAIN_DIR)):
path = os.path.join(TRAIN_DIR, img)
img_data = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
img_data = cv2.resize(img_data, (IMG_SIZE, IMG_SIZE))
training_data.append([np.array(img_data), create_label(img)])
shuffle(training_data)
np.save('train_data.npy', training_data)
return training_data
def create_test_data():
testing_data = []
for img in tqdm(os.listdir(TEST_DIR)):
path = os.path.join(TEST_DIR, img)
img_num = img.split('.')[0]
img_data = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
img_data = cv2.resize(img_data, (IMG_SIZE, IMG_SIZE))
testing_data.append([np.array(img_data), img_num])
shuffle(testing_data)
np.save('test_data.npy', testing_data)
return testing_data
# If dataset is not created:
train_data = create_train_data()
test_data = create_test_data()
# If you have already created the dataset:
#train_data = np.load('train_data.npy')
#test_data = np.load('test_data.npy')
train = train_data[:50000]
test = train_data[:10000]
X_train = np.array([i[0] for i in train]).reshape(-1, IMG_SIZE, IMG_SIZE, 1)
y_train = [i[1] for i in train]
X_test = np.array([i[0] for i in test]).reshape(-1, IMG_SIZE, IMG_SIZE, 1)
y_test = [i[1] for i in test]
# Building The Model
tf.reset_default_graph()
convnet = input_data(shape=[None, IMG_SIZE, IMG_SIZE, 1], name='input')
convnet = conv_2d(convnet, 64, 3, activation='relu')
convnet = max_pool_2d(convnet, 3)
convnet = conv_2d(convnet, 128, 3, activation='relu')
convnet = max_pool_2d(convnet, 3)
convnet = conv_2d(convnet, 256, 3, activation='relu')
convnet = max_pool_2d(convnet, 3)
convnet = conv_2d(convnet, 128, 3, activation='relu')
convnet = max_pool_2d(convnet, 3)
convnet = conv_2d(convnet, 64, 3, activation='relu')
convnet = max_pool_2d(convnet, 3)
convnet = fully_connected(convnet, 1024, activation='relu')
convnet = dropout(convnet, 0.8)
convnet = fully_connected(convnet, 10, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')
model = tflearn.DNN(convnet, tensorboard_dir='log', tensorboard_verbose=2)
history = model.fit({'input': X_train}, {'targets': y_train}, n_epoch=50,
validation_set=({'input': X_test}, {'targets': y_test}),
snapshot_step=1250, show_metric=True, run_id=MODEL_NAME)
fig = plt.figure(figsize=(25, 12))
for num, data in enumerate(test_data[:50]):
img_num = data[1]
img_data = data[0]
y = fig.add_subplot(10, 10, num + 1)
orig = img_data
data = img_data.reshape(IMG_SIZE, IMG_SIZE, 1)
model_out = model.predict([data])[0]
if np.argmax(model_out) == 0:
str_label = 'Cat'
if np.argmax(model_out) == 1:
str_label = 'Dog'
if np.argmax(model_out) == 2:
str_label = 'Automobile'
if np.argmax(model_out) == 3:
str_label = 'Airplane'
if np.argmax(model_out) == 4:
str_label = 'Ship'
if np.argmax(model_out) == 5:
str_label = 'frog'
if np.argmax(model_out) == 6:
str_label = 'truck'
if np.argmax(model_out) == 7:
str_label = 'bird'
if np.argmax(model_out) == 8:
str_label = 'horse'
if np.argmax(model_out) == 9:
str_label = 'deer'
#if condition ends here
y.imshow(orig, cmap='gray')
plt.title(str_label)
y.axes.get_xaxis().set_visible(False)
y.axes.get_yaxis().set_visible(False) #for loop ends here
plt.show()
"""saving the model using tflearn"""
model.save('CNN.tfl')
# I am using the following code for testing purpose
global a
img_data1 = cv2.imread(a, cv2.IMREAD_COLOR)
img_data1 = cv2.resize(img_data1, (IMG_SIZE, IMG_SIZE))
data1 = img_data1.reshape(IMG_SIZE, IMG_SIZE, 3)
model_out = model.predict([data1])[0]
if np.argmax(model_out) == 0:
str_label = 'Cat'
if np.argmax(model_out) == 1:
str_label = 'Dog'
if np.argmax(model_out) == 2:
str_label = 'Automobile'
if np.argmax(model_out) == 3:
str_label = 'Airplane'
if np.argmax(model_out) == 4:
str_label = 'Ship'
if np.argmax(model_out) == 5:
str_label = 'frog'
if np.argmax(model_out) == 6:
str_label = 'truck'
if np.argmax(model_out) == 7:
str_label = 'bird'
if np.argmax(model_out) == 8:
str_label = 'horse'
if np.argmax(model_out) == 9:
str_label = 'deer'
“a”是传递给它的图像,代码是从 TK-Inter 上制作的 GUI 中挑选出来的,a 获取图像的路径并将其传递给模型。
我很生气地看到训练结束时显示的准确度如此之高,而我在测试时得到的模型预测仍然如此错误,即使对于明显的图像也是如此。
如果有人可以帮助我解决这个问题,我将不胜感激。
谢谢 !!
解决方案
推荐阅读
- javascript - React js-点击链接后如何防止页面重新加载。现在整个页面都在点击链接刷新
- javascript - 我的 add 函数没有给出任何错误,它也不起作用
- javascript - Angular:如果当前浏览器是 Internet Explorer,如何重定向到静态的 unsupported.html 页面?
- java - 使用通知和 getLaunchIntentForPackage 打开 Android 应用程序不会通过 LauncherActivity
- bash - 在 AWK 中替换子字符串的任何更快的方法
- r - 多列的 R/Tesseract:如何识别文本的不同部分?
- spring - 如何强制 Spring Boot 将正确的 HTTP 状态代码设置为错误响应?
- azure - 可以在无上处理 TCP 消息的 Azure 服务
- python - brew如何安装python3.6
- javascript - 未捕获的类型错误:无法读取 null 的属性“addEventListener”(多种形式)