首页 > 解决方案 > 包含大量信息的 Matplotlib/Seaborn 箱线图

问题描述

我需要绘制 3 个不同数据集的实验结果,其中 11 个模型超过 4 个指标,每个指标有 100 个数据点。我知道将所有这些信息整合到一个情节中非常困难,而且可能更难以阅读。目前,我有 12 个图,每个实验/指标 (3*4) 一个,所有 11 个模型的结果都在一个图中。我试图减少这 4 个图:每个指标一个,11 个模型和 3 个实验在同一个图中。

更具体地说,我使用的指标是敏感性、特异性、PPV 和 AUC。有 3 个不同的实验:notes_commonfull_common_vitalfull_common_all。我有 11 个模型。目前,这是我必须准备好数据以进行绘图的代码:

prefix = 'notes_common_vital'

bams = pickle.load(open(workdir/f'{prefix}_bams.pkl', 'rb'))
for k in bams.keys():
  bams[k.upper()] = bams.pop(k)
bams['AVG-ALL'] = bams.pop('AVG-LR-RF-GBM')
bams['MAX-ALL'] = bams.pop('MAX-LR-RF-GBM')

itr = iter(bams.keys())
bams.keys()

metrics = {}

for md in itr:
  df = pd.DataFrame()
  for k, m in bams[md].yield_metrics():
    df[k] = m
  df['model'] = md
  cols = list(df.columns)
  cols = [cols[-1]] + cols[:-1]
  df = df[cols]
  metrics[md] = df

plot_df = pd.concat(metrics.values())

bams只是我创建的自定义类的一个对象,用于保存超过 100 次迭代的二进制平均指标。

plot_df.shape
(1100, 5)

plot_df.columns
Index(['model', 'Sensitivity', 'Specificity', 'PPV', 'AUC'], dtype='object')

plot_df.head()

    model   Sensitivity Specificity PPV AUC
0   LR  0.782575    0.607646    0.389910    0.763138
1   LR  0.810860    0.537603    0.362753    0.752767
2   LR  0.823888    0.598635    0.341402    0.784208
3   LR  0.810928    0.617947    0.356734    0.782843
4   LR  0.833948    0.553218    0.333702    0.765500

对于绘图:

met = 'Sensitivity'

fig, ax = plt.subplots(1,1,figsize=(15,8))
sns.boxplot(x='model', y=met, data=plot_df, ax=ax)
ax.set_xlabel('')

这导致 在此处输入图像描述

现在,我通过更改prefix和更改每个指标来为每个实验执行此操作,从而met生成 12 个这样的图。这对我来说太多了,无法在演示文稿中呈现,所以我需要一种更简洁地呈现这些结果的方法。

我在想,我是否可以为每个指标绘制一个图,并将每个实验的所有模型都放在 X 轴上(但宽度非常小),图像很宽,这样给定图中就有 33 个模型,所以我可以更容易地进行比较。我不知道该怎么做。我欢迎其他关于如何呈现这些结果的建议。

谢谢。

标签: pandasmatplotlibseaborn

解决方案


将三个数据集中的每一个组合(连接)到一个数据框中,每个数据集都有一个标识符变量列。然后你可以完全按照你的原样绘制它,但包括hue=用于分隔数据集的参数。这是一个例子。

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

dates = pd.date_range('01-01-2000', '12-31-2002')
df = pd.DataFrame(np.random.randint(0,100,size=(len(dates), 11)), columns=list('ABCDEFGHIJK'))
df.index = dates
df = df.reset_index().melt(value_vars=list('ABCDEFGHIJK'), id_vars='index').set_index('index')
df['dataset'] = df.index.year
df

>     index    variable value   dataset         
>     2000-01-01    A   47  2000
>     2000-01-02    A   89  2000
>     2000-01-03    A   79  2000
>     2000-01-04    A   24  2000
>     2000-01-05    A   87  2000
>     ...   ... ... ...
>     2002-12-27    K   62  2002
>     2002-12-28    K   67  2002
>     2002-12-29    K   46  2002
>     2002-12-30    K   62  2002
>     2002-12-31    K   73  2002
>     12056 rows × 3 columns

plt.figure(figsize=(10,8))
sns.boxplot(x = 'variable', y = 'value', hue = 'dataset', data = df)

在此处输入图像描述


推荐阅读