首页 > 解决方案 > 混淆矩阵Python中数字的位置不合适

问题描述

我有这样的代码:

plt.figure(figsize=(8,5))
confusion_matrix = pd.crosstab(y_test, predictions, rownames=["Observed"], colnames=["Anticipated"])
sns.heatmap(confusion_matrix, annot=True, fmt= "d")
plt.show()

然而,每个方格内的数字并不在每个方格的中心,如下所示。如何更改此代码以使数字位于每个正方形的中心位置? 在此处输入图像描述

标签: pythonpandasseaborn

解决方案


datascience.stackexchange.com 上的这篇文章所示,matplotlib 3.1.1坏了sns.heatmap()。那里的答案建议降级到matplotlib 3.1.0。但是我已经在我的机器上安装了 3.1.2 并且它正在工作。因此,您现在也许可以升级。

y_test = np.array(
    [
        "foo", "foo", "foo", "foo",
        "bar", "bar", "bar", "bar",
        "foo", "foo", "foo"
    ], dtype=object
)
predictions = np.array(
    [
        "one", "one", "one", "two",
        "one", "one", "one", "two",
        "two", "two", "one"
    ], dtype=object
)

plt.figure(figsize=(8, 5))
confusion_matrix = pd.crosstab(
    y_test, predictions, rownames=["Observed"], colnames=["Anticipated"]
)
sns.heatmap(confusion_matrix, annot=True, fmt="d")
plt.show()

在此处输入图像描述

这是pip list使用相关软件包的结果。

Package            Version   Location
------------------ --------- -------------------------------------
matplotlib         3.1.2
pandas             0.25.1
seaborn            0.9.0

推荐阅读