首页 > 解决方案 > 我无法理解使用 argmax() 删除 OHE 后得到的混淆矩阵

问题描述

我无法解释我的混淆矩阵。我得到低于价值的错误。

ValueError:不支持多标签指示符

在阅读了许多帖子后,我意识到这个问题可能是由于预测中的OHE(一种热编码)造成的。因此,为了解决它,我按照各种帖子中的建议使用了 argmax()。下面是我的代码:

from sklearn.metrics import confusion_matrix
print(Y.shape)
print(predictions.shape)
print(Y)
print(predictions)
# print(confusion_matrix(Y, predictions))
print(confusion_matrix(Y.argmax(axis = 1), predictions.argmax(axis = 1)))

(1, 200)
(1, 200)
[[1 1 0 0 1 1 0 1 0 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 0 1 0 1 0 1 0 0 1 1 1 0
  0 1 0 1 0 1 0 0 1 0 1 0 0 0 1 0 1 1 0 0 1 0 1 0 1 0 1 0 0 1 1 1 0 0 0 1
  0 1 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0 1 1 0 0 1 1 0 1 1 1 1 0 1 0 1 1 1 1 1
  0 0 0 1 0 1 1 1 0 1 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 0 0 0 0 1 1 0 0 0 1 0
  0 0 1 1 0 1 1 1 1 1 1 0 0 0 1 1 1 0 1 0 1 0 1 0 0 1 1 1 1 1 0 0 1 1 1 1
  0 1 0 0 1 0 1 0 1 0 1 0 1 0 1 0 0 1 1 0]]
[[1 1 0 0 1 1 0 1 0 0 1 0 0 0 0 1 0 0 1 0 1 1 1 1 0 1 0 1 0 1 0 0 1 1 0 0
  0 1 0 1 0 1 0 0 1 0 0 0 0 0 0 0 1 1 0 0 1 0 1 0 1 1 1 0 1 1 1 1 1 0 0 1
  0 1 0 1 0 0 0 1 1 0 0 0 0 1 0 0 1 1 1 0 0 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1
  0 0 0 1 0 1 1 1 0 1 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 1 0 0 0 1 1 0 0 0 0 0
  0 0 1 1 0 0 1 1 1 1 1 1 1 0 0 0 1 0 1 1 1 0 1 0 0 1 1 1 1 1 0 1 1 1 1 0
  0 1 0 0 1 0 1 1 1 0 1 0 1 0 1 0 0 1 1 1]]
[[1]]

从输出中可以看出,我得到[[1]]了混淆矩阵。我不知道如何解释它。我期待一个 2x2 的混淆矩阵,然后我会继续计算精度、召回率、F1 分数等,以了解我的模型的性能。请建议我做错了什么?

标签: pythonnumpymachine-learningscikit-learnneural-network

解决方案


IIUC 的问题在于输入数组的形状。你需要先把它们弄平。这是一个重现您的案例的示例:

from sklearn.metrics import confusion_matrix

Y = np.random.choice([0,1],size=(1,10))
pred = np.random.choice([0,1],size=(1,10))

由于在您的示例中两个数组都是二维的,confusion_matrix因此解释您有多标签输出,它不支持:

confusion_matrix(Y, pred)
ValueError: multilabel-indicator is not supported

您需要展平两个数组:

confusion_matrix(Y.ravel(), pred.ravel())

推荐阅读