首页 > 解决方案 > Seaborn 热图使注释文本适合单元格

问题描述

我有一段用于显示混淆矩阵的代码。每个单元格中先显示准确率,然后在其下方显示“正确预测样本数/总样本数”。现在我想让每个单元格内的文本完整显示出来。例如,第一个单元格应在准确率下方显示 186/208。如何才能在单元格内显示注释的全文?我尝试过减小字体大小,但没有效果。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def cm_analysis(cm, labels, figsize=(20,15)):
    """Draw an annotated confusion-matrix heatmap and save it as ``filename.png``.

    This is the version from the question, kept exactly as posted because it
    reproduces the reported problem (cell annotations not shown in full).

    Args:
        cm: 2-D NumPy array of counts; it is indexed as ``cm[i, j]`` and its
            rows are summed to obtain per-row totals.
        labels: Labels used for both the index and the columns of the plotted
            DataFrame.
        figsize: Figure size forwarded to ``plt.subplots``.

    Returns:
        None. The figure is written to ``filename.png``.
    """
    # Row totals kept as a column vector (keepdims=True), so the division on
    # the next line broadcasts row by row.
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float)
    # NOTE(review): this makes a NumPy *string* array. Such arrays have a
    # built-in maximum item length, so longer annotations assigned below can
    # be cut off; an object array (astype(object)) does not have this limit.
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                # Diagonal cell: ratio on the first line, "count/row total"
                # on the second.
                # NOTE(review): p is a fraction in [0, 1]; it is printed with
                # a percent sign without being multiplied by 100.
                s = cm_sum[i]
                annot[i, j] = '%.2f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                # Empty cells get no annotation.
                annot[i, j] = ''
            else:
                # Off-diagonal cell: ratio and raw count.
                annot[i, j] = '%.2f%%\n%d' % (p, c)
    # From here on `cm` is a labelled DataFrame rather than the input array.
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Groundtruth labels'
    cm.columns.name = 'Predicted labels'
    fig, ax = plt.subplots(figsize=figsize)
    ax.axhline(color='black')

    # fmt='' because `annot` already holds fully formatted strings.
    g =sns.heatmap(cm, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax, cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
    g.set_xticklabels(g.get_xticklabels(), rotation = 45)
    # NOTE(review): called after the heatmap has been drawn — presumably it
    # does not affect the text already created above; confirm intent.
    sns.set(font_scale=1.1)
    plt.savefig("filename.png")

# Class names, in the same order as the rows/columns of the matrix below.
classes = [
    'Assemble system',
    'Consult sheets',
    'Picking in front',
    'Picking left',
    'Put down component',
    'Put down measuring rod',
    'Put down screwdriver',
    'Put down subsystem',
    'Take component',
    'Take measuring rod',
    'Take screwdriver',
    'Take subsystem',
    'Turn sheets',
]

# 13 x 13 table of sample counts: one row per ground-truth class, one column
# per predicted class (these are the axis names cm_analysis puts on the plot).
normalised_confusion_matrix = np.array([
    [186, 3, 0, 1, 2, 0, 3, 3, 7, 1, 2, 0, 0],
    [5, 9, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 49, 3, 0, 0, 0, 0, 1, 0, 0, 0, 6],
    [1, 0, 6, 89, 0, 0, 0, 0, 1, 1, 1, 0, 1],
    [3, 7, 0, 0, 50, 0, 0, 0, 6, 0, 1, 0, 0],
    [1, 0, 0, 0, 0, 9, 0, 1, 0, 0, 0, 0, 0],
    [3, 0, 1, 0, 0, 0, 54, 0, 0, 0, 3, 0, 0],
    [2, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0],
    [3, 0, 0, 0, 2, 1, 2, 0, 53, 2, 4, 0, 0],
    [0, 0, 0, 1, 0, 1, 0, 0, 1, 7, 0, 1, 0],
    [1, 1, 0, 0, 1, 0, 1, 0, 3, 0, 52, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 5, 0],
    [0, 0, 11, 2, 0, 0, 0, 0, 0, 0, 0, 0, 26],
])

# Render the heatmap with the default figure size.
cm_analysis(normalised_confusion_matrix, classes)

在此处输入图像描述

标签: python seaborn

解决方案


主要问题在于 annot 数组被创建为 str 类型而不是 object 类型(因此应改为 annot = np.empty_like(cm).astype(object))。使用 str 类型会导致奇怪的错误,因为 NumPy 字符串数组的元素有内置的最大长度,超出的部分会被截断。(另见这篇文章。)

由于您在 cm_sum[i] 中只使用了一个索引,因此最好不要“保留维度”:cm_sum = np.sum(cm, axis=1, keepdims=False)(文档)。需要注意的是,这样 cm_sum 会变成一维数组,计算百分比时应写成 cm / cm_sum[:, None],才能仍然按行归一化;否则广播会让每一列除以对应行的总数,非对角单元格的百分比就会出错。

另外请注意,要显示百分比,需要将比值乘以 100。(创建格式化字符串的现代写法是使用 f-string:annot[i, j] = f'{p*100:.2f}%\n{c}/{s}'。)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def cm_analysis(cm, labels, figsize=(20, 15)):
    """Draw an annotated confusion-matrix heatmap and save it as ``filename.png``.

    Every non-empty cell is annotated with its share of the row total (as a
    percentage) and the raw count; diagonal cells additionally show the row
    total as ``count/total``.

    Args:
        cm: 2-D array-like of sample counts; rows are ground-truth classes and
            columns are predicted classes.
        labels: Class labels used for both the index and the columns of the
            plotted DataFrame; must match the matrix dimensions.
        figsize: Figure size forwarded to ``plt.subplots``.

    Returns:
        None. The figure is written to ``filename.png`` in the working
        directory.
    """
    cm = np.asarray(cm)
    row_totals = cm.sum(axis=1)

    # Normalise each ROW by its own total. The totals are 1-D, so they must be
    # reshaped to a column vector; dividing by the flat vector would broadcast
    # along the last axis and divide column j by the total of row j instead.
    # Rows without any samples keep 0.0 instead of producing NaN and a warning.
    totals_column = row_totals[:, None]
    cm_perc = np.divide(
        cm,
        totals_column,
        out=np.zeros(cm.shape, dtype=float),
        where=totals_column != 0,
    )

    # Object dtype: a NumPy str array has a fixed item width and would
    # truncate the longer annotations.
    annot = np.empty(cm.shape, dtype=object)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = row_totals[i]
                annot[i, j] = f'{p*100:.1f}%\n{c}/{s}'
            elif c == 0:
                # Leave empty cells unannotated to reduce clutter.
                annot[i, j] = ''
            else:
                annot[i, j] = f'{p*100:.1f}%\n{c}'

    frame = pd.DataFrame(cm, index=labels, columns=labels)
    frame.index.name = 'Groundtruth labels'
    frame.columns.name = 'Predicted labels'
    fig, ax = plt.subplots(figsize=figsize)
    ax.axhline(color='black')

    # fmt='' because `annot` already holds fully formatted strings.
    g = sns.heatmap(frame, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax,
                    cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
    g.set_xticklabels(g.get_xticklabels(), rotation=45)
    # NOTE(review): called after the heatmap is drawn, so presumably it only
    # influences figures created later; kept in place to preserve the
    # existing output.
    sns.set(font_scale=1.1)
    plt.savefig("filename.png")

# Sample counts per (ground-truth row, predicted column) pair for 13 classes.
_count_rows = [
    [186, 3, 0, 1, 2, 0, 3, 3, 7, 1, 2, 0, 0],
    [5, 9, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 49, 3, 0, 0, 0, 0, 1, 0, 0, 0, 6],
    [1, 0, 6, 89, 0, 0, 0, 0, 1, 1, 1, 0, 1],
    [3, 7, 0, 0, 50, 0, 0, 0, 6, 0, 1, 0, 0],
    [1, 0, 0, 0, 0, 9, 0, 1, 0, 0, 0, 0, 0],
    [3, 0, 1, 0, 0, 0, 54, 0, 0, 0, 3, 0, 0],
    [2, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0],
    [3, 0, 0, 0, 2, 1, 2, 0, 53, 2, 4, 0, 0],
    [0, 0, 0, 1, 0, 1, 0, 0, 1, 7, 0, 1, 0],
    [1, 1, 0, 0, 1, 0, 1, 0, 3, 0, 52, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 5, 0],
    [0, 0, 11, 2, 0, 0, 0, 0, 0, 0, 0, 0, 26],
]
normalised_confusion_matrix = np.array(_count_rows)

# Labels in matrix order; used for both axes of the heatmap.
classes = [
    'Assemble system', 'Consult sheets',
    'Picking in front', 'Picking left',
    'Put down component', 'Put down measuring rod',
    'Put down screwdriver', 'Put down subsystem',
    'Take component', 'Take measuring rod',
    'Take screwdriver', 'Take subsystem',
    'Turn sheets',
]

# Draw and save the annotated heatmap using the default figure size.
cm_analysis(normalised_confusion_matrix, classes)

修正后的热图


推荐阅读