首页 > 解决方案 > 在pythom中打印多类随机森林的混淆矩阵

问题描述

我正在使用 sklearn.RandomForestClassifier,我有 11 个课程。我的数据在数据框中,所有变量都是热编码的。类是字符串,例如“Potato”、“Tomato”、“Straberry”等。

当我尝试打印混淆矩阵时,我得到以下信息:

print(pd.crosstab(y_test, y_pred))

Error: If using all scalar values, you must pass an index

当我传递索引时:

print(pd.crosstab(y_test, y_pred, index = [0]))

Error:crosstab() got multiple values for argument 'index'

解决这个问题的最佳方法是什么?

标签: pythonpandasmachine-learningscikit-learnrandom-forest

解决方案


该错误表明您需要将参数“索引”传递给交叉表,而不是帮助您遍历列表的索引。您可以在此处找到正确的方法和更多详细信息

您还可以使用以下代码在 Sci-Kit Learn 中绘制混淆矩阵

此代码从用于混淆矩阵的训练数据中获取所有标签

label=y_train.unique()
label=np.sort(label)

此代码导入混淆矩阵并绘制它。plt.cm.Blues用于配色方案并且clf是您的分类器,请务必使用您命名的分类器进行更改。

from sklearn.metrics import plot_confusion_matrix
cm=plot_confusion_matrix(clf,X_test, y_test,labels=label,cmap=plt.cm.Blues)

推荐阅读