python - 从函数调用时显示混淆矩阵
问题描述
我有一个从 scikit learn 导入随机森林分类器的函数,我用数据拟合它,最后我想显示准确度、kappa 和混淆矩阵。除了打印混淆矩阵之外的所有工作。我没有收到任何错误,但没有打印混淆矩阵。
我已经尝试调用print(cm)
并且它可以工作,但它不会以通常的 pandas 数据框样式打印,这正是我正在寻找的。
这是代码
def rf_clf(X, y, test_size = 0.3, random_state = 42):
"""This function splits the data into train and test and fits it in a random forest classifier
to the data provided, analysing its errors (Accuracy and Kappa). Also as this is classification,
the function will output a confusion matrix"""
#Split data in train and test, as well as predictors (X) and targets, (y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state, stratify=y)
#import random forest classifier
base_model = RandomForestClassifier(random_state=random_state)
#Train the model
base_model.fit(X_train,y_train)
#make predictions on test set
y_pred=base_model.predict(X_test)
#Print Accuracy and Kappa
print("Accuracy:",metrics.accuracy_score(y_test, y_pred))
print("Kappa:",metrics.cohen_kappa_score(y_test, y_pred))
#create confusion matrix
labs = [y_test[i][0] for i in range(len(y_test))]
cm = pd.DataFrame(confusion_matrix(labs, y_pred))
cm #here is the issue. Kinda works with print(cm)
解决方案
一开始从 sklearn 导入指标。
from sklearn import metrics
当你想显示混淆矩阵时使用它。
# Get and show confussion matrix cm = metrics.confusion_matrix(y_test, y_pred) print(cm)
有了这个,您应该在原始文本中查看混淆矩阵。
如果您想用颜色显示混淆矩阵,请以其他方式进行:
进口
from sklearn.metrics import confusion_matrix import pandas as pd import seaborn as sns; sns.set()
以这种方式使用它:
cm = confusion_matrix(y_test, y_pred) cmat_df = pd.DataFrame(cm, index=class_names, columns=class_names) ax = sns.heatmap(cmat_df, square=True, annot=True, cbar=False) ax.set_xlabel('Predicción') ax.set_ylabel('Real')`
希望最好的!
推荐阅读
- hyperledger-fabric - 一个节点可以在两个不同的结构网络中吗?
- c++ - 我应该如何使用条件变量解决哲学家就餐问题?
- java - 如何在 gmail 登录表单中获取包含错误消息的 Web 元素
- python-3.x - Pygame ValueError: 无效颜色参数问题
- java - 如何检测哪个片段处于活动状态?
- fody - 仅选择自动实现属性的选项?
- eclipse-scout - SAAS 应用程序,每个公司的子域
- android - Android-Socket 不发送自定义对象
- arrays - React native 在数组中给对象一个名称
- python - python队列是否使用GIL?