首页 > 解决方案 > 如何在 seaborn 的 lmplot 中添加比较线?

问题描述

我想结合我拥有的以下 lmplots。更具体地说,红线是每个季节的平均值,我想将它们与其他数据放在各自的 lmplots 上,而不是将它们分开。这是我的代码(注意,轴限制不起作用,因为第二个 lmplot 把它弄乱了。当我只绘制初始数据时它起作用):

ax = sns.lmplot(data=data, x='air_yards', y='cpoe',col='season', lowess = True, scatter_kws={'alpha':.6, 'color': '#4F2E84'}, line_kws={'alpha':.6, 'color': '#4F2E84'})

ax = sns.lmplot(data=avg, x='air_yards', y= 'cpoe',lowess=True, scatter=False, line_kws={'linestyle':'--', 'color': 'red'}, col = 'season')

axes.set_xlim([-5,30])
axes.set_ylim([-25,25])

ax.set(xlabel='air yards')

这是输出。简单地说,我想把那些红线放在上面各自的年份图上。谢谢! 这是输出

标签: pythonpandasmatplotlibplotseaborn

解决方案


不确定它是否可能是您想要的方式,所以可能是这样的:

import matplotlib.pyplot as plt
import seaborn as sns

#dummy example
data = pd.DataFrame({'air_yards': range(1,11), 
                     'cpoe': range(1,11), 
                     'season': [1,2,3,2,1,3,2,1,3,2]})
avg = pd.DataFrame({'air_yards': [1, 10]*3, 
                    'cpoe': [2,2,5,5,8,8], 
                    'season': [1,1,2,2,3,3]})

# need this info
n = data["season"].nunique()

# create the number of subplots
fig, axes = plt.subplots(ncols=n, sharex=True, sharey=True)

# now you need to loop through unique season
for ax, (season, dfg) in zip(axes.flat, data.groupby("season")):
    # set title
    ax.set_title(f'season={season}')

    # create the replot for data
    sns.regplot("air_yards", "cpoe", data=dfg, ax=ax, 
                lowess = True, scatter_kws={'alpha':.6, 'color': '#4F2E84'}, 
                line_kws={'alpha':.6, 'color': '#4F2E84'})

    # create regplot for avg
    sns.regplot("air_yards", "cpoe", data=avg[avg['season'].eq(season)], ax=ax, 
                lowess=True, scatter=False, 
                line_kws={'linestyle':'--', 'color': 'red'})

plt.show()

你得到 在此处输入图像描述


推荐阅读