首页 > 解决方案 > 使用 sns lineplot 绘制平均线

问题描述

我有一个看起来像这样的数据框:

id|date    |amount
1 |02-04-18|3000
1 |05-04-19|5000
1 |10-04-19|2600
2 |10-04-19|2600
2 |11-04-19|3000

我想为每个唯一 ID 随着时间的推移花费的金额,并有一个平均趋势线。这是我拥有的代码:

import matplotlib.pyplot as plt
import pandas as pd

temp_m = df.pivot_table(index='id',columns='id',values='amount', fill_value=0)
temp_m = pd.melt(temp, id_vars=['id'])
temp_m['date'] = temp_m['date'].astype('str')
fig, ax = plt.subplots(figsize=(20,10))
for i, group in temp_m.groupby('id'):
    group.plot('id', y='amount', ax=ax,legend=None)
    plt.xticks(rotation = 90)

在此处输入图像描述

每条生产线都是一个独特的客户。

目标:我想添加另一条线,它是所有单个客户趋势的平均值。

另外,如果还有更好的方法来绘制各条线,请告诉我

标签: pythonpandasmatplotlibseaborn

解决方案


首先我们重塑数据

agg = df.set_index(['date', 'id']).unstack()
agg.columns = agg.columns.get_level_values(-1)

这使得绘图非常容易:

sns.lineplot(data=agg)

平均趋势可以通过以下方式计算

from sklearn.linear_model import LinearRegression

regress = {}
idx = agg.index.to_julian_date()[:, None]
for c in agg.columns:
    regress[c] = LinearRegression().fit(idx, agg[c].fillna(0)).predict(idx)

trend = pd.Series(pd.DataFrame(regress).mean(axis=1).values, agg.index)

推荐阅读