python - 如何在 Python 中创建混淆矩阵的图像
问题描述
我是 Python 和机器学习的新手。我正在研究多类分类(3类)。我想将混淆矩阵保存为图像。现在,sklearn.metrics.confusion_matrix()
帮助我找到混淆矩阵,如:
array([[35, 0, 6],
[0, 0, 3],
[5, 50, 1]])
接下来,我想知道如何将这个混淆矩阵转换为图像并保存为png。
解决方案
选项 1:
从 获取混淆矩阵数组后sklearn.metrics
,您可以使用matplotlib.pyplot.matshow()
或seaborn.heatmap
从该数组生成混淆矩阵图。
例如
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
cfm = [[35, 0, 6],
[0, 0, 3],
[5, 50, 1]]
classes = ["0", "1", "2"]
df_cfm = pd.DataFrame(cfm, index = classes, columns = classes)
plt.figure(figsize = (10,7))
cfm_plot = sn.heatmap(df_cfm, annot=True)
cfm_plot.figure.savefig("cfm.png")
选项 2:
您可以使用plot_confusion_matrix()
fromsklearn
直接从估计器(即分类器)创建混淆矩阵的图像。
例如
cfm_plot = plot_confusion_matrix(<estimator>, <X>, <Y>)
cfm_plot.savefig("cfm.png")
这两个选项都savefig()
用于将结果保存为 png 文件。
参考:https ://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html