pandas - 包含大量信息的 Matplotlib/Seaborn 箱线图
问题描述
我需要绘制 3 个不同数据集的实验结果,其中 11 个模型超过 4 个指标,每个指标有 100 个数据点。我知道将所有这些信息整合到一个情节中非常困难,而且可能更难以阅读。目前,我有 12 个图,每个实验/指标 (3*4) 一个,所有 11 个模型的结果都在一个图中。我试图减少这 4 个图:每个指标一个,11 个模型和 3 个实验在同一个图中。
更具体地说,我使用的指标是敏感性、特异性、PPV 和 AUC。有 3 个不同的实验:notes_common、full_common_vital、full_common_all。我有 11 个模型。目前,这是我必须准备好数据以进行绘图的代码:
prefix = 'notes_common_vital'
bams = pickle.load(open(workdir/f'{prefix}_bams.pkl', 'rb'))
for k in bams.keys():
bams[k.upper()] = bams.pop(k)
bams['AVG-ALL'] = bams.pop('AVG-LR-RF-GBM')
bams['MAX-ALL'] = bams.pop('MAX-LR-RF-GBM')
itr = iter(bams.keys())
bams.keys()
metrics = {}
for md in itr:
df = pd.DataFrame()
for k, m in bams[md].yield_metrics():
df[k] = m
df['model'] = md
cols = list(df.columns)
cols = [cols[-1]] + cols[:-1]
df = df[cols]
metrics[md] = df
plot_df = pd.concat(metrics.values())
bams
只是我创建的自定义类的一个对象,用于保存超过 100 次迭代的二进制平均指标。
plot_df.shape
(1100, 5)
plot_df.columns
Index(['model', 'Sensitivity', 'Specificity', 'PPV', 'AUC'], dtype='object')
plot_df.head()
model Sensitivity Specificity PPV AUC
0 LR 0.782575 0.607646 0.389910 0.763138
1 LR 0.810860 0.537603 0.362753 0.752767
2 LR 0.823888 0.598635 0.341402 0.784208
3 LR 0.810928 0.617947 0.356734 0.782843
4 LR 0.833948 0.553218 0.333702 0.765500
对于绘图:
met = 'Sensitivity'
fig, ax = plt.subplots(1,1,figsize=(15,8))
sns.boxplot(x='model', y=met, data=plot_df, ax=ax)
ax.set_xlabel('')
现在,我通过更改prefix
和更改每个指标来为每个实验执行此操作,从而met
生成 12 个这样的图。这对我来说太多了,无法在演示文稿中呈现,所以我需要一种更简洁地呈现这些结果的方法。
我在想,我是否可以为每个指标绘制一个图,并将每个实验的所有模型都放在 X 轴上(但宽度非常小),图像很宽,这样给定图中就有 33 个模型,所以我可以更容易地进行比较。我不知道该怎么做。我欢迎其他关于如何呈现这些结果的建议。
谢谢。
解决方案
将三个数据集中的每一个组合(连接)到一个数据框中,每个数据集都有一个标识符变量列。然后你可以完全按照你的原样绘制它,但包括hue=
用于分隔数据集的参数。这是一个例子。
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
dates = pd.date_range('01-01-2000', '12-31-2002')
df = pd.DataFrame(np.random.randint(0,100,size=(len(dates), 11)), columns=list('ABCDEFGHIJK'))
df.index = dates
df = df.reset_index().melt(value_vars=list('ABCDEFGHIJK'), id_vars='index').set_index('index')
df['dataset'] = df.index.year
df
> index variable value dataset
> 2000-01-01 A 47 2000
> 2000-01-02 A 89 2000
> 2000-01-03 A 79 2000
> 2000-01-04 A 24 2000
> 2000-01-05 A 87 2000
> ... ... ... ...
> 2002-12-27 K 62 2002
> 2002-12-28 K 67 2002
> 2002-12-29 K 46 2002
> 2002-12-30 K 62 2002
> 2002-12-31 K 73 2002
> 12056 rows × 3 columns
plt.figure(figsize=(10,8))
sns.boxplot(x = 'variable', y = 'value', hue = 'dataset', data = df)
推荐阅读
- mysql - 如何在一张表上查找不同 ID 的数量
- javascript - 如何使路由自动转到反应路由器中的嵌套路由路径
- javascript - 如何在 jQuery 中显示视频播放器的持续时间?格式如下:00:00
- angular - 如何在 Angular Material 的单列中添加多个值?
- google-sheets - 在谷歌表格中添加新标签时,如何在主表格上自动填充信息?
- git - Git 删除提交而不删除合并提交
- python-3.x - Python Pytest 模拟三个函数
- javascript - 浏览器如何暴露当前正在播放的音频?
- java - Spring Data JPA 不会在运行模式下持久化对象,但它会在调试模式下持久化
- javascript - 获取脚本中 var 的值