Plotting multiple line charts for years with incomplete data

Problem description

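I have a dataset spanning 4 years that I want to plot on one chart, with each year as a separate series. The data contains daily records from March 2015 to August 2018, and I want to aggregate and display it by month.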

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# train: DataFrame of daily rows with 'Year' (stored as string), 'Month' and 'SalesValue' columns
plt.clf()  # clear figures
plt.figure(figsize=(16, 8))

x = np.arange(0, 12, 1)  # 12 x positions, one per calendar month
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
total_sales_2015 = train.loc[train['Year'] == '2015'].groupby('Month')['SalesValue'].sum() / 1000.0  # in thousands
total_sales_2016 = train.loc[train['Year'] == '2016'].groupby('Month')['SalesValue'].sum() / 1000.0  # in thousands
total_sales_2017 = train.loc[train['Year'] == '2017'].groupby('Month')['SalesValue'].sum() / 1000.0  # in thousands
total_sales_2018 = train.loc[train['Year'] == '2018'].groupby('Month')['SalesValue'].sum() / 1000.0  # in thousands

plt.plot(x, total_sales_2015, label="2015")  # doesn't work: only 10 data points (Mar-Dec)
plt.plot(x, total_sales_2016, label="2016")  # works
plt.plot(x, total_sales_2017, label="2017")  # works
plt.plot(x, total_sales_2018, label="2018")  # doesn't work: only 8 data points (Jan-Aug)

How can I show the partial years on the chart? When I run the code above, it raises the following error: "ValueError: x and y must have same first dimension"
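The error comes from the length mismatch: x always has 12 entries, while the groupby result for a partial year only has one entry per month that actually contains data. The minimal sketch below uses made-up numbers (the series sales_2015 is hypothetical) and assumes zero-padded month strings as the index. It shows the mismatch and one simple workaround: plot each year against its own month numbers instead of a fixed x.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical monthly totals for a partial year: March..December 2015 (10 months),
# indexed by zero-padded month strings ('03' .. '12')
sales_2015 = pd.Series(np.arange(10, dtype=float),
                       index=[f"{m:02d}" for m in range(3, 13)])

x = np.arange(0, 12, 1)          # 12 x positions, one per calendar month
print(len(x), len(sales_2015))   # 12 10 -> plt.plot(x, sales_2015) raises ValueError

# Workaround: use this year's own month numbers as x values (1 = Jan ... 12 = Dec)
months_2015 = sales_2015.index.astype(int)
plt.plot(months_2015, sales_2015.values, label="2015")
plt.xticks(range(1, 13))
plt.legend()
```

This keeps every year on a shared 1 to 12 axis without padding, at the cost of building x separately for each series.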

Tags: python, pandas, matplotlib, plot

Solution


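You can use Series.reindex with an index of all possible (year, month) combinations, created by MultiIndex.from_product. Combinations missing from the data become NaN, so every year ends up with exactly 12 values: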

import numpy as np
import pandas as pd

# sample data: 2015 only has Mar-Dec, 2018 only has Jan-Aug
np.random.seed(123)
train = pd.DataFrame({'Year':['2015'] * 10 + ['2018'] * 8,
                      'Month': list(range(3, 13)) + list(range(1, 9)),
                      'SalesValue':np.random.randint(1000, size=18)})
train['Month'] = train['Month'].astype(str).str.zfill(2)
print(train)
    Year Month  SalesValue
0   2015    03         510
1   2015    04         365
2   2015    05         382
3   2015    06         322
4   2015    07         988
5   2015    08          98
6   2015    09         742
7   2015    10          17
8   2015    11         595
9   2015    12         106
10  2018    01         123
11  2018    02         569
12  2018    03         214
13  2018    04         737
14  2018    05          96
15  2018    06         113
16  2018    07         638
17  2018    08          47

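# monthly totals in thousands, indexed by (Year, Month)
total_sales = train.groupby(['Year','Month'])['SalesValue'].sum() / 1000

# every (year, month) combination that should appear on the chart
years = np.arange(2015, 2019).astype(str)
months = pd.Series(np.arange(1, 13, 1)).astype(str).str.zfill(2)

mux = pd.MultiIndex.from_product([years, months], names=total_sales.index.names)

# combinations not present in the data are filled with NaN
total_sales = total_sales.reindex(mux)

print(total_sales)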

Year  Month
2015  01         NaN
      02         NaN
      03       0.510
      04       0.365
      05       0.382
      06       0.322
      07       0.988
      08       0.098
      09       0.742
      10       0.017
      11       0.595
      12       0.106
2016  01         NaN
      02         NaN
      03         NaN
      04         NaN
      05         NaN
      06         NaN
      07         NaN
      08         NaN
      09         NaN
      10         NaN
      11         NaN
      12         NaN
2017  01         NaN
      02         NaN
      03         NaN
      04         NaN
      05         NaN
      06         NaN
      07         NaN
      08         NaN
      09         NaN
      10         NaN
      11         NaN
      12         NaN
2018  01       0.123
      02       0.569
      03       0.214
      04       0.737
      05       0.096
      06       0.113
      07       0.638
      08       0.047
      09         NaN
      10         NaN
      11         NaN
      12         NaN
Name: SalesValue, dtype: float64

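import matplotlib.pyplot as plt

x = np.arange(12)  # Jan..Dec; every year now has 12 values, NaN months appear as gaps
plt.plot(x, total_sales.loc['2015'], label="2015")
plt.plot(x, total_sales.loc['2016'], label="2016")
plt.plot(x, total_sales.loc['2017'], label="2017")
plt.plot(x, total_sales.loc['2018'], label="2018")
plt.legend()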
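Putting the pieces together, a self-contained sketch of the whole reindex approach might look like the following. It rebuilds the sample data from above, loops over the years instead of repeating plt.plot four times, and labels the x-axis with month names through plt.xticks; the y-axis label text is an assumption added for readability.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Same sample data as above: 2015 covers Mar-Dec, 2018 covers Jan-Aug
np.random.seed(123)
train = pd.DataFrame({'Year': ['2015'] * 10 + ['2018'] * 8,
                      'Month': list(range(3, 13)) + list(range(1, 9)),
                      'SalesValue': np.random.randint(1000, size=18)})
train['Month'] = train['Month'].astype(str).str.zfill(2)

# Monthly totals in thousands, reindexed to all 4 x 12 (year, month) pairs
total_sales = train.groupby(['Year', 'Month'])['SalesValue'].sum() / 1000
years = np.arange(2015, 2019).astype(str)
month_keys = [f"{m:02d}" for m in range(1, 13)]
mux = pd.MultiIndex.from_product([years, month_keys], names=total_sales.index.names)
total_sales = total_sales.reindex(mux)

month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
               'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
x = np.arange(12)

plt.figure(figsize=(16, 8))
for year in years:
    # Each year's slice has exactly 12 values; NaN months are drawn as gaps
    plt.plot(x, total_sales.loc[year].to_numpy(), label=year)
plt.xticks(x, month_names)
plt.ylabel('Sales (thousands)')
plt.legend()
```

In this sample data 2016 and 2017 are entirely NaN, so their legend entries have no visible line; with the full dataset they would be complete.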

If it is fine for the x-axis to show the months themselves, use Series.unstack together with DataFrame.plot:

# DataFrame.plot creates its own figure, so pass figsize here instead of calling plt.figure()
total_sales.unstack(level=0).plot(figsize=(16, 8))
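If you start from the raw daily rows, DataFrame.pivot_table can do the grouping and the reshaping in one step, producing the same month-by-year layout that unstack gives. The sketch below assumes a datetime column named Date (hypothetical; the question does not show the raw column names) and fills it with dummy sales of 1.0 per day. The reindex(range(1, 13)) call guarantees 12 rows even if some calendar month never occurs in the data; months without data in a given year stay NaN and show up as gaps.

```python
import numpy as np
import pandas as pd

# Hypothetical daily data; the 'Date' column name and the values are assumptions
dates = pd.date_range('2015-03-01', '2018-08-31', freq='D')
daily = pd.DataFrame({'Date': dates, 'SalesValue': np.ones(len(dates))})

daily['Year'] = daily['Date'].dt.year
daily['Month'] = daily['Date'].dt.month

# Rows = calendar month (1-12), columns = year, values = monthly sum in thousands
monthly = (daily.pivot_table(index='Month', columns='Year',
                             values='SalesValue', aggfunc='sum')
                .reindex(range(1, 13)) / 1000)

# One line per year; Jan-Feb 2015 and Sep-Dec 2018 are NaN and left blank
monthly.plot(figsize=(16, 8))
```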
