python-3.x - 为什么CNN只预测一类
问题描述
我有一个模型需要检测植物是死是活。它只是预测一个类别,即数据不平衡,但我使用权重来应对不平衡。
我已经查看了很多关于这个问题的问题,但似乎没有一个有效,显然这个问题是在过度拟合时发生的,所以我使用了 dropout。但该模型仍然只预测一个类别。
继承人的模型:
model=Sequential()
# Convolutional layer / input layer
model.add(Conv2D(60, 5,5, activation='relu', input_shape=np.shape(X[1])))
model.add(MaxPooling2D(pool_size=(3,3)))
model.add(Dropout(0.8))
model.add(Flatten())
model.add(Dropout(0.7))
model.add(Dense(130, activation='relu'))
model.add(Dropout(0.6))
# Output layer
model.add(Dense(2, activation='softmax'))
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit(X, y, epochs=6, batch_size=32, class_weight=class_weight, validation_data=(X_test, y_test))
通常它应该预测两个类别:1:健康植物和 0:不健康植物
解决方案
由于您的问题是二进制分类并且您的输出维度为 2,因此您应该将激活更改为 softmax。
model.add(Dense(2, activation='softmax'))
但是,如果您想保留 sigmoid,只需将输出层单位更改为 1,这样您将输出您的输入是仅有一个单位的两个类之一的可能性有多大。
model.add(Dense(1, activation='sigmoid'))
推荐阅读
- python - 根据列值的条件组合列索引
- android - 以编程方式比较(简单)移动设备上的绘图模式
- c++ - 如何使用 Halsted's Metric 计算给定代码的大小?
- redirect - IIS URL 重写 - 页面转发
- c++ - 不同标头中使用的设置结构
- ssh - 远程 - VSCode 中的 SSH 发出“进程试图写入不存在的管道”
- virtual-machine - 无法在虚拟机中启用集成网络摄像头
- swift - 渲染自定义注释时 MKMapView 冻结
- snowflake-schema - 雪花查询缓慢
- java - 如何在不使用 iText7 加密的情况下创建不可编辑的 PDF