python - 即使在四舍五入后也无法获得分类报告
问题描述
我正在尝试使用 keras 对我的数据集进行分类,但ValueError: Classification metrics can't handle a mix of multiclass and multilabel-indicator targets
出现错误。中的值y_pred
如下
array([[2.95522604e-02, 9.70325887e-01, 3.20542094e-05, ...,
1.74383260e-07, 1.98587145e-07, 9.88743452e-08],
[3.25102806e-01, 6.68996394e-01, 1.65001326e-03, ...,
5.84201662e-05, 5.91963508e-05, 4.68929684e-05],
[8.87618303e-01, 1.12024814e-01, 1.22764613e-04, ...,
1.44616331e-06, 1.33618846e-06, 1.68983024e-06],
...,
[3.09438616e-01, 6.83520675e-01, 1.94711238e-03, ...,
7.57295784e-05, 7.51852640e-05, 5.94857411e-05],
[6.73729360e-01, 3.21534157e-01, 1.41171378e-03, ...,
4.93246625e-05, 4.61974196e-05, 4.73670734e-05],
[1.33120596e-01, 8.64127636e-01, 7.41749362e-04, ...,
1.87505502e-05, 1.95825924e-05, 1.34223355e-05]], dtype=float32)
我正在将它们四舍五入,如本问题中所述,因为y_test
值是
array([1, 0, 0, ..., 0, 1, 1])
y_pred
与y_pred = y_pred.round().astype(int)
我四舍五入后
array([[0, 1, 0, ..., 0, 0, 0],
[1, 0, 0, ..., 0, 0, 0],
[1, 0, 0, ..., 0, 0, 0],
...,
[0, 1, 0, ..., 0, 0, 0],
[1, 0, 0, ..., 0, 0, 0],
[0, 1, 0, ..., 0, 0, 0]])
即使在此之后,当我尝试使用print(metrics.classification_report(y_test, y_pred))
我得到与上述相同的错误来获取分类报告时。有人可以帮我看看我在这里做错了什么吗?谢谢
解决方案
scikit -learn 文档指出y_pred
输入必须是1d array-like。你需要 argmax 你的 logits。
import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report
y_pred = tf.math.abs(tf.random.normal([32, 2])).numpy()
y_test = tf.random.uniform([32, 1], minval=0, maxval=2, dtype=tf.int32).numpy()
# this will explode
print(classification_report(y_test, y_pred))
# ValueError: Classification metrics can't handle a mix of binary and
# continuous-multioutput targets
# get predicted indices
y_pred = np.argmax(y_pred, 1)
# try again
print(classification_report(y_test, y_pred))
# precision recall f1-score support
#
# 0 0.41 0.50 0.45 14
# 1 0.53 0.44 0.48 18
#
# accuracy 0.47 32
# macro avg 0.47 0.47 0.47 32
# weighted avg 0.48 0.47 0.47 32
推荐阅读
- firebase - 当我尝试将图像上传到 Firebase 存储时,putFile 在 kotlin 中不起作用
- database - 在 PostgreSQL 中的 UPSERT 上返回 PK id 时出错,阻止 ON CONFLICT 运行
- javascript - 使用 .join 映射数组时遇到问题
- json - 解析 json 未确定数量的嵌套与 hive
- matrix - 如何获取矩阵中每一列和每一行的最小值的索引?在 Python3
- rust - 无法实现 Ord 时如何使用 BinaryHeap?
- sqlite - 如何将android的db sqlite与使用nodejs制作的rest服务同步?
- android - 为什么当我返回活动时 ViewModel 不提供数据?
- electron - Electron.js 多个加速器
- go - 如何解释这些时间戳