首页 > 解决方案 > 如何在 Python 中创建混淆矩阵的图像

问题描述

我是 Python 和机器学习的新手。我正在研究多类分类(3类)。我想将混淆矩阵保存为图像。现在,sklearn.metrics.confusion_matrix()帮助我找到混淆矩阵,如:

array([[35, 0, 6],
   [0, 0, 3],
   [5, 50, 1]])

接下来,我想知道如何将这个混淆矩阵转换为图像并保存为png。

标签: pythonmachine-learningscikit-learnconfusion-matrixmulticlass-classification

解决方案


选项 1

从 获取混淆矩阵数组后sklearn.metrics,您可以使用matplotlib.pyplot.matshow()seaborn.heatmap从该数组生成混淆矩阵图。

例如

import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt

cfm = [[35, 0, 6],
       [0, 0, 3],
       [5, 50, 1]]
classes = ["0", "1", "2"]

df_cfm = pd.DataFrame(cfm, index = classes, columns = classes)
plt.figure(figsize = (10,7))
cfm_plot = sn.heatmap(df_cfm, annot=True)
cfm_plot.figure.savefig("cfm.png")

在此处输入图像描述


选项 2

您可以使用plot_confusion_matrix()fromsklearn直接从估计器(即分类器)创建混淆矩阵的图像。

例如

cfm_plot = plot_confusion_matrix(<estimator>, <X>, <Y>)
cfm_plot.savefig("cfm.png")

这两个选项都savefig()用于将结果保存为 png 文件。

参考:https ://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html


推荐阅读