python - 对于 tensorflow.keras.models.Sequential.predict 中的多类多标签问题,是否可能只得到 0 和 1?
问题描述
假设我有 3 个类,每个样本都可以属于这些类中的任何一个。标签看起来像这样。
[
[1 0 0]
[0 1 0]
[0 0 1]
[1 1 0]
[1 0 1]
[0 1 1]
[1 1 1]
]
我将输出设置为 Dense(3, activation="sigmoid"),并使用 optimizer="adam", loss="binary_crossentropy" 进行编译。根据 Keras 的输出,我猜损失为 0.05,准确率为 0.98。
我想如果我使用 sigmoid 和 binary_crossentropy,我只会得到 1 或 0 的预测值。然而,model.predict(training-features) 给了我 1 到 0 之间的值,比如 0.0026。我已经尝试了 categorical_crossentropy 和 binary_crossentropy 与 sigmoid 和 softmax 之间的所有 4 种组合。Model.predict 总是返回一个介于 0 和 1 之间的值,形状为 n_samples by n_classes。在上面的示例中它将是 7x3。
然后我将值剪裁为 0.5,如下所示并检查了 accuracy_score(training_labels,preds)。分数下降到0.1。
preds[preds>=0.5] = 1
preds[preds<0.5] = 0
如果有人能给我一些关于如何解决这个问题的指导,我将不胜感激。
谢谢!
解决方案
根据您的描述,这是一个多标签分类问题,因此您应该将sigmoid
其用作最后一层的激活函数和binary_crossentropy
损失函数。那是因为我们认为每个标签的分类独立于所有其他标签。因此,在这种情况下使用softmax
orcategorical_crossentropy
是错误的。
Keras 报告的准确度与使用sklearn.metrics.accuracy_score()
函数计算的准确度之间的差异不是由于四舍五入造成的;实际上,Keras 进行了与计算精度相同的舍入(或裁剪)。相反,差异是由于accuracy_score
多标签分类模式下的函数仅在该样本的所有真实标签和预测标签相互匹配时才认为该样本被正确分类。这已在文档中明确说明:
在多标签分类中,此函数计算子集精度:为样本预测的标签集必须与 y_true 中的相应标签集完全匹配。
然而,在 Kerasbinary_accuracy
函数中报告正确分类标签的平均分数(即部分匹配是可以接受的)。为了更好地理解这一点,请考虑以下示例:
True labels | Predictions | Keras binary acc | accuracy_score
-----------------------------------------------------------------
[1 0 0] | [1 0 1] | 2 correct = 0.66 | not match = 0.00
[0 1 1] | [0 1 1] | 3 correct = 1.00 | match = 1.00
[1 0 1] | [0 0 1] | 2 correct = 0.66 | not match = 0.00
=================================================================
average reported acc | 0.77 | 0.33
推荐阅读
- javascript - 如何通过 ID 从 Apollo 缓存中读取嵌套对象?
- reactjs - 每当在 reactjs 中选中复选框时清除文本输入
- ios - 在 SwiftUI 中关闭地图视图 (MKMapView)
- android - 如何在 Jacoco 报告生成中公开“sourceDirectories”?
- unity3d - 子弹总是从静态位置发射
- c# - XAMARIN 上是否有优雅的 IOS 数字按钮
- java - Multiply an int by 30, 31, 32 - are these really optimized by the compiler? (effective java says so)
- php - 如何在同一个刀片模板(视图)上查看两个表数据?
- postman - SQLSTATE [23000]:违反完整性约束:1048 列“tokenable_id”不能为空
- c++ - .h 和 .cpp 文件分离时出错,但仅使用 .h 文件时没有错误。我究竟做错了什么?