python-3.x - 模型只预测所有测试图像的一类
问题描述
0
我是 tensorflow 的新手(我使用的是 2.1.0 版)并且遇到了问题。我想教我的模型对图像进行分类。我有4节课。我使用 CNN,在训练模型时,我的准确率为 25%(整组的 1/4),我注意到在测试类中总是预测到第 1 类。我尝试了一切,放大和缩小,尝试了彩色和黑白照片。我的套装每个班级有 3 400 张 32x32 的照片。请帮忙。
我的代码如下:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shutil
import plotly.graph_objects as go
from sklearn.metrics import confusion_matrix, classification_report
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.applications import VGG19
from tensorflow.keras import losses
np.set_printoptions(precision=6, suppress=True)
image_size = 32
base_dir = 'resize\\train'
raw_no_of_files = {}
classes = ['car', 'motorcycle', 'truck', 'building']
for dir in classes:
raw_no_of_files[dir] = len(os.listdir(os.path.join(base_dir, dir)))
print(raw_no_of_files.items())
data_dir = 'res'
if not os.path.exists(data_dir):
os.mkdir(data_dir)
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
test_dir = os.path.join(data_dir, 'test')
train_car_dir = os.path.join(train_dir, 'car')
train_motorcycle_dir = os.path.join(train_dir, 'motorcycle')
train_truck_dir = os.path.join(train_dir, 'truck')
train_building_dir = os.path.join(train_dir, 'building')
valid_car_dir = os.path.join(valid_dir, 'car')
valid_motorcycle_dir = os.path.join(valid_dir, 'motorcycle')
valid_truck_dir = os.path.join(valid_dir, 'truck')
valid_building_dir = os.path.join(valid_dir, 'building')
test_car_dir = os.path.join(test_dir, 'car')
test_motorcycle_dir = os.path.join(test_dir, 'motorcycle')
test_truck_dir = os.path.join(test_dir, 'truck')
test_building_dir = os.path.join(test_dir, 'building')
for directory in (train_dir, valid_dir, test_dir):
if not os.path.exists(directory):
os.mkdir(directory)
dirs = [train_car_dir, train_motorcycle_dir, train_truck_dir, train_building_dir,
valid_car_dir, valid_motorcycle_dir, valid_truck_dir, valid_building_dir,
test_car_dir, test_motorcycle_dir, test_truck_dir, test_building_dir]
for dir in dirs:
if not os.path.exists(dir):
os.mkdir(dir)
car_fnames = os.listdir(os.path.join(base_dir, 'car'))
motorcycle_fnames = os.listdir(os.path.join(base_dir, 'motorcycle'))
truck_fnames = os.listdir(os.path.join(base_dir, 'truck'))
building_fnames = os.listdir(os.path.join(base_dir, 'building'))
size = min(len(car_fnames), len(motorcycle_fnames), len(truck_fnames), len(building_fnames))
train_size = int(np.floor(0.7 * size))
valid_size = int(np.floor(0.2 * size))
test_size = size - train_size - valid_size
train_idx = train_size
valid_idx = train_size + valid_size
test_idx = train_size + valid_size + test_size
for i, fname in enumerate(car_fnames):
if i <= train_idx:
src = os.path.join(base_dir, 'car', fname)
dst = os.path.join(train_car_dir, fname)
shutil.copyfile(src, dst)
elif train_idx < i <= valid_idx:
src = os.path.join(base_dir, 'car', fname)
dst = os.path.join(valid_car_dir, fname)
shutil.copyfile(src, dst)
elif valid_idx < i < test_idx:
src = os.path.join(base_dir, 'car', fname)
dst = os.path.join(test_car_dir, fname)
shutil.copyfile(src, dst)
for i, fname in enumerate(motorcycle_fnames):
if i <= train_idx:
src = os.path.join(base_dir, 'motorcycle', fname)
dst = os.path.join(train_motorcycle_dir, fname)
shutil.copyfile(src, dst)
elif train_idx < i <= valid_idx:
src = os.path.join(base_dir, 'motorcycle', fname)
dst = os.path.join(valid_motorcycle_dir, fname)
shutil.copyfile(src, dst)
elif valid_idx < i < test_idx:
src = os.path.join(base_dir, 'motorcycle', fname)
dst = os.path.join(test_motorcycle_dir, fname)
shutil.copyfile(src, dst)
for i, fname in enumerate(truck_fnames):
if i <= train_idx:
src = os.path.join(base_dir, 'truck', fname)
dst = os.path.join(train_truck_dir, fname)
shutil.copyfile(src, dst)
elif train_idx < i <= valid_idx:
src = os.path.join(base_dir, 'truck', fname)
dst = os.path.join(valid_truck_dir, fname)
shutil.copyfile(src, dst)
elif valid_idx < i < test_idx:
src = os.path.join(base_dir, 'truck', fname)
dst = os.path.join(test_truck_dir, fname)
shutil.copyfile(src, dst)
for i, fname in enumerate(building_fnames):
if i <= train_idx:
src = os.path.join(base_dir, 'building', fname)
dst = os.path.join(train_building_dir, fname)
shutil.copyfile(src, dst)
elif train_idx < i <= valid_idx:
src = os.path.join(base_dir, 'building', fname)
dst = os.path.join(valid_building_dir, fname)
shutil.copyfile(src, dst)
elif valid_idx < i < test_idx:
src = os.path.join(base_dir, 'building', fname)
dst = os.path.join(test_building_dir, fname)
shutil.copyfile(src, dst)
print('samochód - zbiór treningowy', len(os.listdir(train_car_dir)))
print('samochód - zbiór walidacyjny', len(os.listdir(valid_car_dir)))
print('samochód - zbiór testowy', len(os.listdir(test_car_dir)))
print('motocykl - zbiór treningowy', len(os.listdir(train_motorcycle_dir)))
print('motocykl - zbiór walidacyjny', len(os.listdir(valid_motorcycle_dir)))
print('motocykl - zbiór testowy', len(os.listdir(test_motorcycle_dir)))
print('tir - zbiór treningowy', len(os.listdir(train_truck_dir)))
print('tir - zbiór walidacyjny', len(os.listdir(valid_truck_dir)))
print('tir - zbiór testowy', len(os.listdir(test_truck_dir)))
print('budowlane - zbiór treningowy', len(os.listdir(train_building_dir)))
print('budowlane - zbiór walidacyjny', len(os.listdir(valid_building_dir)))
print('budowlane - zbiór testowy', len(os.listdir(test_building_dir)))
train_datagen = ImageDataGenerator(
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
valid_datagen = ImageDataGenerator()
train_generator = train_datagen.flow_from_directory(directory=train_dir,
target_size=(image_size, image_size),
batch_size=32,
class_mode='categorical')
valid_generator = valid_datagen.flow_from_directory(directory=valid_dir,
target_size=(image_size, image_size),
batch_size=32,
class_mode='categorical')
batch_size = 32
steps_per_epoch = train_size // batch_size
validation_steps = valid_size // batch_size
model = Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.summary()
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(4))
model.summary()
model.compile(optimizer=optimizers.RMSprop(lr=1e-5),
loss='categorical_crossentropy',
metrics=['acc'])
model.summary()
history = model.fit_generator(generator=train_generator,
steps_per_epoch=steps_per_epoch,
epochs=30, # 100
validation_data=valid_generator,
validation_steps=validation_steps)
test_datagen = ImageDataGenerator()
test_generator = test_datagen.flow_from_directory(
test_dir,
target_size=(image_size, image_size),
batch_size=1,
class_mode='categorical',
shuffle=False
)
model.save('cars_' + str(size) + '.h5')
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, label='Dokładność trenowania')
plt.plot(epochs, val_acc, label='Dokładność walidacji')
plt.xlabel('Epoka')
plt.ylabel('Dokładność')
plt.title('Dokładność trenowania i walidacji')
plt.legend()
plt.figure()
plt.plot(epochs, loss, label='Strata trenowania')
plt.plot(epochs, val_loss, label='Strata walidacji')
plt.title('Strata trenowania i walidacji')
plt.legend()
plt.show()
y_prob = model.predict_generator(test_generator, test_generator.samples)
print(type(y_prob))
print(y_prob)
#
#
y_pred = np.argmax(y_prob, axis=1)
print(type(y_pred))
print(y_pred)
#
predictions = pd.DataFrame({'class': y_pred})
print(predictions)
y_true = test_generator.classes
print(y_true)
y_pred = predictions['class'].values
print(y_pred)
print(test_generator.class_indices)
classes = list(test_generator.class_indices.keys())
print(classes)
cm = confusion_matrix(y_true, y_pred)
print(cm)
#
#
def plot_confusion_matrix(cm):
cm = cm[::-1]
cm = pd.DataFrame(cm, columns=classes, index=classes[::-1])
fig = ff.create_annotated_heatmap(z=cm.values, x=list(cm.columns), y=list(cm.index), colorscale='ice', showscale=True, reversescale=True)
fig.update_layout(width=500, height=500, title='Confusion Matrix', font_size=16)
fig.show()
import plotly.figure_factory as ff
plot_confusion_matrix(cm)
#
print(classification_report(y_true, y_pred, target_names=test_generator.class_indices.keys()))
errors = pd.DataFrame({'y_true': y_true, 'y_pred': y_pred}, index=test_generator.filenames)
print(errors)
errors['is_incorrect'] = (errors['y_true'] != errors['y_pred']) * 1
print(errors)
print(errors[errors['is_incorrect'] == 1].index)
解决方案
推荐阅读
- javascript - Javascript中有没有一种方法可以根据用户的输入过滤产品列表并删除该项目而不将其从我的数据库中删除?
- python - “无效的日期时间格式”使用 .read_sql() 读取 MS Access 日期/时间值
- javascript - JavaScript 对象:无法向对象添加新属性
- python - 如何将 unicode 字符串写入 json?
- python - Pb:'不再支持将列表喜欢传递给带有任何缺失标签的 .loc 或 []
- javascript - 如何让 PeerJS 视频通话与 Heroku 一起工作?
- c++ - 将多个独立的 Rust 库链接到 C++ 二进制文件时需要注意什么负面后果?
- amazon-web-services - 从 Alexa 开发人员控制台到 AWS 托管的 DynamoDB 的 CRUD
- css - 绝对位置元素不能溢出其他弹性项目
- oauth - 您是否应该序列化调用以刷新 oauth 令牌