Writing values in a heatmap-like plot, but for categorical variables, in seaborn

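Question

I have plotted a DataFrame as a heatmap-like figure and I want to write text into the cells. The text should not be the cell's value: I compare each value against some conditions and want the cell to say which kind of error it has.

For example: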

import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns  # in a notebook, also run: %matplotlib inline
data = []
for i in range(10):
    data.append([random.randrange(0, 11, 1) for _ in range(10)])
df = pd.DataFrame(data)
n = 10

fig, ax = plt.subplots(figsize = (12, 10)) 
cmap = ['#b3e6b3','#66cc66','#2d862d','#ffc299','#ff944d','#ff6600','#ccddff','#99bbff','#4d88ff','#0044cc','#002b80']
ax = sns.heatmap(df, cmap=cmap, linewidths = 0.005, annot = False) 
                            
plt.show()

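Running this code gives me:

[Image: the heatmap without any values written in the cells]

I then compare the DataFrame df against some conditions and get two more DataFrames, for example: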

condition1 = [['Error A'] + [np.nan]*9,
            [np.nan]*6 + ['Error C'] + [np.nan]*3,
            [np.nan]*10,
            [np.nan]*7 + ['Error B'] + [np.nan]*2,
            [np.nan]*2 + ['Error D'] + [np.nan]*3 + ['Error B'] + [np.nan]*3,
            [np.nan]*10,
            [np.nan]*3 + ['Error B'] + [np.nan]*6,
            [np.nan]*7 + ['Error A'] + [np.nan]*2,
            [np.nan]*10,
            [np.nan]*10]
df_condition1 = pd.DataFrame(data = condition1)

condition2 = [[np.nan]*10,[np.nan]*10,
            [np.nan]*10,[np.nan]*7 + ['Error C'] + [np.nan]*2,
            [np.nan]*10,[np.nan]*10,[np.nan]*10,
            [np.nan]*10,
            [np.nan]*10,
            [np.nan]*10]
df_condition2 = pd.DataFrame(data = condition2)

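What I want is to show the values of these DataFrames in the heatmap, like this:

[Image: the desired heatmap, with the error labels written in the cells]

How can I do this?

Tags: python, pandas, dataframe, seaborn

Answer


You can build the error text yourself and annotate the cells manually: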

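# Boolean masks marking the cells where each condition reports an error
c1, c2 = df_condition1.notna(), df_condition2.notna()
# Replace NaN with empty strings so the labels can be concatenated
df_condition1, df_condition2 = df_condition1.fillna(''), df_condition2.fillna('')

# Cells with both errors get both labels on separate lines; otherwise use whichever one exists
errors = np.select((c1 & c2, c1, c2),
                   (df_condition1 + '\n' + df_condition2, df_condition1, df_condition2),
                   '')

fig, ax = plt.subplots(figsize=(12, 10))
cmap = ['#b3e6b3','#66cc66','#2d862d','#ffc299','#ff944d','#ff6600','#ccddff','#99bbff','#4d88ff','#0044cc','#002b80']
ax = sns.heatmap(df, cmap=cmap, linewidths=0.005, annot=False)

# Cell (r, c) is centered at x = c + 0.5, y = r + 0.5 in the heatmap's coordinates
for r in range(errors.shape[0]):
    for c in range(errors.shape[1]):
        ax.text(c + 0.5, r + 0.5, errors[r, c],
                va='center', ha='center',
                fontweight='bold')

plt.show()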

Output:

[Image: the resulting heatmap]
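
To check the label-merging step without drawing anything, the sketch below applies the same np.select idea to two tiny frames. The helper name combine_error_labels and the 2x2 sample data are assumptions made for illustration; they are not part of the original answer or of any library.

```python
import numpy as np
import pandas as pd


def combine_error_labels(first: pd.DataFrame, second: pd.DataFrame) -> np.ndarray:
    """Merge two same-shaped label frames into one array of annotation strings.

    Hypothetical helper for illustration only; not a pandas or seaborn function.
    """
    # Masks of the cells that actually carry a label.
    has_first = first.notna().to_numpy()
    has_second = second.notna().to_numpy()

    # Keep the label where one exists, otherwise an empty string.
    # Object dtype lets "+" concatenate the strings element by element.
    text_first = np.where(has_first, first.astype(str).to_numpy(), "").astype(object)
    text_second = np.where(has_second, second.astype(str).to_numpy(), "").astype(object)

    # Both labels on separate lines when both are present,
    # otherwise whichever single label exists, otherwise "".
    both = text_first + "\n" + text_second
    return np.select(
        [has_first & has_second, has_first, has_second],
        [both, text_first, text_second],
        default="",
    )


# Made-up sample data: one cell has a label in both frames.
first = pd.DataFrame([["Error A", np.nan], [np.nan, "Error B"]])
second = pd.DataFrame([[np.nan, np.nan], [np.nan, "Error C"]])

labels = combine_error_labels(first, second)
print(labels.tolist())
```

The returned array can be used in place of errors in the annotation loop above. seaborn's heatmap also accepts an array with the same shape as the data through annot, combined with fmt='' for string labels, which may make the manual loop unnecessary; that variant is not tested here.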

