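python - Seaborn heatmap: making the annotation text fit inside the cells
Problem description
I have this code that displays a confusion matrix. In each cell, the accuracy is shown first, and below it the number of correctly predicted samples over the total number of samples. Now I want to show all of the text inside each cell. For example, the first cell should show 186/208 below the accuracy. How can I display the full text of the annotation inside the cell? I tried reducing the font size, but that did not work.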
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def cm_analysis(cm, labels, figsize=(20,15)):
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float)
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = cm_sum[i]
                annot[i, j] = '%.2f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.2f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Groundtruth labels'
    cm.columns.name = 'Predicted labels'
    fig, ax = plt.subplots(figsize=figsize)
    ax.axhline(color='black')
    g = sns.heatmap(cm, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax, cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
    g.set_xticklabels(g.get_xticklabels(), rotation = 45)
    sns.set(font_scale=1.1)
    plt.savefig("filename.png")

normalised_confusion_matrix = np.array(
    [[186,3,0,1,2,0,3,3,7,1,2,0,0],
     [5,9,1,0,3,0,0,0,0,0,0,0,1],
     [0,0,49,3,0,0,0,0,1,0,0,0,6],
     [1,0,6,89,0,0,0,0,1,1,1,0,1],
     [3,7,0,0,50,0,0,0,6,0,1,0,0],
     [1,0,0,0,0,9,0,1,0,0,0,0,0],
     [3,0,1,0,0,0,54,0,0,0,3,0,0],
     [2,0,0,0,0,0,2,7,0,0,0,0,0],
     [3,0,0,0,2,1,2,0,53,2,4,0,0],
     [0,0,0,1,0,1,0,0,1,7,0,1,0],
     [1,1,0,0,1,0,1,0,3,0,52,0,0],
     [1,0,0,0,0,0,0,0,1,0,0,5,0],
     [0,0,11,2,0,0,0,0,0,0,0,0,26]]
)
classes = ['Assemble system','Consult sheets','Picking in front','Picking left','Put down component','Put down measuring rod','Put down screwdriver','Put down subsystem','Take component','Take measuring rod','Take screwdriver','Take subsystem','Turn sheets']
cm_analysis(cm= normalised_confusion_matrix, labels = classes)
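Solution
The main problem is that the annot array is created with type str instead of object (so, annot = np.empty_like(cm).astype(object)). Having it as type str leads to strange errors, because NumPy string arrays have a built-in maximum length and anything longer is silently cut off. (See also this post.)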
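To make the truncation concrete, here is a minimal sketch. The width of 5 characters ('<U5') is chosen only for illustration; in the question's code the width is whatever astype(str) derives from the integer dtype, which can differ between platforms.

```python
import numpy as np

text = '89.42%\n186/208'

# Fixed-width unicode array: every element holds at most 5 characters.
# (The width 5 is an arbitrary choice for this illustration.)
fixed = np.empty((1, 1), dtype='<U5')
fixed[0, 0] = text

# Object array: every element is a reference to an ordinary Python string.
flexible = np.empty((1, 1), dtype=object)
flexible[0, 0] = text

print(repr(fixed[0, 0]))     # cut off after 5 characters, without any warning
print(repr(flexible[0, 0]))  # full text, including the second line
```

Because the second line of the annotation comes last, it is the part that goes missing in the heatmap cell.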
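Since you only use a single index in cm_sum[i], it is better not to "keep dimensions": cm_sum = np.sum(cm, axis=1, keepdims=False) (docs). The row totals are then a 1-D array, so the division that computes the fractions needs them reshaped into a column, for example cm / cm_sum[:, np.newaxis]; otherwise broadcasting divides each entry by the total of the wrong row.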
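The effect of keepdims can be sketched on a small matrix. The 2x2 values below are made up for illustration and are not taken from the question.

```python
import numpy as np

# Toy confusion matrix with different row totals (10 and 4).
cm = np.array([[8, 2],
               [1, 3]])

sums_keep = np.sum(cm, axis=1, keepdims=True)   # shape (2, 1)
sums_flat = np.sum(cm, axis=1, keepdims=False)  # shape (2,)

# Indexing the 2-D result gives a one-element array, the 1-D result a scalar.
print(f'{sums_keep[0]}')  # formatted like an array
print(f'{sums_flat[0]}')  # formatted like a number

# Row-wise fractions: the divisor has to be a column,
# so that row i is divided by its own total.
row_frac = cm / sums_flat[:, np.newaxis]
print(row_frac)
```

With the 1-D totals, cm / sums_flat on its own would broadcast along the columns instead, which only gives the right value on the diagonal.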
Also, note that for a percentage you need to multiply by 100. (A modern way to create the formatted string is to use f-strings: annot[i, j] = f'{p*100:.2f}%\n{c}/{s}'.)
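As a quick check of the formatting, the sketch below builds the annotation for the first diagonal cell (186 correct out of 208) with both the %-style and the f-string style; the variable names old_style, new_style and unscaled are only for this example.

```python
c, s = 186, 208   # correct predictions and row total of the first cell
p = c / s         # fraction between 0 and 1

# %-formatting needs '%%' for a literal percent sign.
old_style = '%.2f%%\n%d/%d' % (p * 100, c, s)

# The f-string version of the same annotation.
new_style = f'{p * 100:.2f}%\n{c}/{s}'

# Without the factor 100 the fraction itself is labelled as a percentage.
unscaled = '%.2f%%' % p

print(old_style)
print(unscaled)
```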
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def cm_analysis(cm, labels, figsize=(20, 15)):
    # Row totals as a 1-D array, so that cm_sum[i] is a plain scalar
    cm_sum = np.sum(cm, axis=1, keepdims=False)
    # Divide every row by its own total; the new axis turns cm_sum into a column
    cm_perc = cm / cm_sum[:, np.newaxis].astype(float)
    # dtype object, so each cell can hold a string of any length
    annot = np.empty_like(cm).astype(object)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                # Diagonal: percentage plus correct/total
                s = cm_sum[i]
                annot[i, j] = f'{p*100:.1f}%\n{c}/{s}'
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = f'{p*100:.1f}%\n{c}'
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Groundtruth labels'
    cm.columns.name = 'Predicted labels'
    fig, ax = plt.subplots(figsize=figsize)
    ax.axhline(color='black')
    # fmt='' because the annotations are already formatted strings
    g = sns.heatmap(cm, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax,
                    cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
    g.set_xticklabels(g.get_xticklabels(), rotation=45)
    sns.set(font_scale=1.1)
    plt.savefig("filename.png")

normalised_confusion_matrix = np.array(
    [[186, 3, 0, 1, 2, 0, 3, 3, 7, 1, 2, 0, 0],
     [5, 9, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1],
     [0, 0, 49, 3, 0, 0, 0, 0, 1, 0, 0, 0, 6],
     [1, 0, 6, 89, 0, 0, 0, 0, 1, 1, 1, 0, 1],
     [3, 7, 0, 0, 50, 0, 0, 0, 6, 0, 1, 0, 0],
     [1, 0, 0, 0, 0, 9, 0, 1, 0, 0, 0, 0, 0],
     [3, 0, 1, 0, 0, 0, 54, 0, 0, 0, 3, 0, 0],
     [2, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0],
     [3, 0, 0, 0, 2, 1, 2, 0, 53, 2, 4, 0, 0],
     [0, 0, 0, 1, 0, 1, 0, 0, 1, 7, 0, 1, 0],
     [1, 1, 0, 0, 1, 0, 1, 0, 3, 0, 52, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 5, 0],
     [0, 0, 11, 2, 0, 0, 0, 0, 0, 0, 0, 0, 26]]
)
classes = ['Assemble system', 'Consult sheets', 'Picking in front', 'Picking left', 'Put down component',
           'Put down measuring rod', 'Put down screwdriver', 'Put down subsystem', 'Take component',
           'Take measuring rod', 'Take screwdriver', 'Take subsystem', 'Turn sheets']
cm_analysis(cm=normalised_confusion_matrix, labels=classes)