python - 使用 Keras 对 3 个类进行图像分类仅返回一个值而不是 1 X 3 数组
问题描述
我正在训练一个用于多标签图像分类的 Keras 模型,即 3 个类别,即洪水、野火、风暴。
但我得到的只是[[1.]]
而不是类似的东西[0 0 1]
。所以如果第三位是一,那就是一场风暴。但我不知道为什么它只返回一个值[[1.]]
。
# # Importing the Keras libraries and packages
from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
import numpy as np
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def create_model() :
# Initialising the CNN
classifier = Sequential()
# Step 1 - Convolution
classifier.add(Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))
# Step 2 - Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Adding a second convolutional layer
classifier.add(Conv2D(32, (3, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Step 3 - Flattening
classifier.add(Flatten())
# Step 4 - Full connection
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dense(units = 1, activation = 'sigmoid'))
return classifier
def train_save_model():
classifier = create_model()
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
# Part 2 - Fitting the CNN to the images
train_datagen = ImageDataGenerator(rescale = 1./255,
shear_range = 0.2,
zoom_range = 0.2,
horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory('training_set',
target_size = (64, 64),
batch_size = 32,
class_mode = 'binary')
test_set = test_datagen.flow_from_directory('validation_set',
target_size = (64, 64),
batch_size = 32,
class_mode = 'binary')
classifier.fit_generator(training_set,
steps_per_epoch = 1407,
epochs = 1,
validation_data = test_set,
validation_steps = 100)
classifier.save_weights("model.h5")
# Part 3 - Making new predictions
def test_model():
classifier = create_model()
classifier.load_weights("model.h5")
test_image = image.load_img('validation_set/tornado/110.jpg', target_size = (64, 64))
test_image = image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis = 0)
# print(test_image)
result = classifier.predict(test_image)
train_datagen = ImageDataGenerator(rescale = 1./255,
shear_range = 0.2,
zoom_range = 0.2,
horizontal_flip = True,
)
training_set = train_datagen.flow_from_directory('training_set',
target_size = (64, 64),
batch_size = 32,
class_mode = 'binary')
training_set.class_indices
# print(training_set.class_indices)
print(result)
train_save_model()
test_model()
result = classifier.predict(test_image)
我尝试打印此result
变量并得到[[1.]]
. 我完全无法理解这是怎么回事。
解决方案
如果你有 N 个标签,那么最后一层(即 sigmoid 分类器层)也必须有 N 个神经元,每个类一个:
classifier.add(Dense(units=3, activation='sigmoid'))
然后,对于每个输入样本,模型的输出将是对应于三个标签的 3 个数字。
更新:class_mode = 'binary'
从所有flow_from_directory
呼叫中删除。那是因为您正在多个类之间进行分类,因此生成的标签应该是分类的(默认行为)或稀疏的(即class_mode='sparse'
)。此外,在阅读了您的代码的相关部分之后,您似乎正在做多类分类,而不是多标签分类。阅读此答案以确保并找出您应该使用哪个激活和损失函数。
推荐阅读
- python - 从函数中获取 Azure 托管标识
- node.js - 安装最新的 Node 版本后无法升级到最新版本的 npm
- python - 我想解压缩部分 zip.001、zip.002 的文件。PYTHON
- python - 从时间序列中提取模式
- project-reactor - Project Reactor 的 Flux 能否一一处理消息
- java - WAR 文件在不同机器上编译时不起作用
- javascript - 在 Cytoscape.js 中使节点内的图像可点击
- ios - Swift Uiview 背景动画不起作用
- java - 无法为最终变量“weekTotalDaylyMinutes”赋值
- mysql - 在多台服务器上使用 mysql-server docker 容器创建 MySQL 集群