python - CNN 模型不适用于 4 个物种,但适用于 2 个物种
问题描述
我在两个类上尝试了 CNN 模型并得到了 80%,但是当我用 4 个类尝试相同的模型时,我得到了非常糟糕的结果。请帮忙的原因是什么。我使用的CNN模型是:
model= Sequential()
model.add(Conv2D(64,(3,3),input_shape=input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64,(3,3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64,(3,3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
#opt = SGD( lr=0.01)
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
history = model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples//batch_size,
epochs=epochs,
validation_data = validation_generator,
validation_steps = validation_generator.samples // batch_size,
)
2个类的结果是这样的,我失去了它的实际结果:
Epoch 29/35
46/46 [==============================] - 188s 4s/step - loss: 0.6511 - accuracy: 0.5880 - val_loss: 0.7534 - val_accuracy: 0.5175
4个类的结果是:
46/46 [==============================] - 367s 8s/step - loss: -10550614391401.7266 - accuracy: 0.2541 - val_loss: -15023441182720.0000 - val_accuracy: 0.2354
解决方案
输出层使用sigmoid
激活函数,只能用于二进制分类问题。
对于两个以上的类,softmax
在它应该有num_of_classes
节点之前使用激活函数和密集层。
model.add(Dense(numclasses)) # numclasses = 4 in your case
model.add(Activation('softmax'))
此外,损失应该从更改binary_crossentropy
为categorical_crossentropy
(这是在您的案例中显示的奇怪损失的主要原因)。
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
注意: categorical_crossentropy
需要one-hot
向量。如果您拥有的标签只是一维数组而不是单热向量,请使用sparse_categorical_crossentropy
推荐阅读
- python - 从两个图像中找到裁剪参数
- python-3.x - Pandas 日期从“2020 年 10 月 23 日;2020 年 8 月 27 日”转换为“2020 年 10 月 23 日;2020 年 8 月 27 日”
- kotlin - kotlinx-serialization - 为什么从基类继承的默认值总是被编码?
- c - strtok 的大小 1 无效
- android - 如何使用 Flutter 加快 Webview 中的自动登录?
- java - 致命信号 11 (SIGSEGV),代码 1,tid 1597 中的故障地址 0x0(电话服务器)
- ios - “任何”类型的值没有成员“标题”
- tomcat9 - 无法使用可流动的战争启动 tomcat 9-PUBLIC.ACT_DE_DATABASECHANGELOGLOCK 错误
- node.js - Firebase CORS - 未处理的错误类型错误:无法读取未定义的属性来源
- css - 我想更改 div 的样式取决于是否检查了该 div 内的输入?