首页 > 解决方案 > 如何从 matplotlib 中具有多个分类列的数据框制作线图

问题描述

我想为不同的类别制作折线图,其中一个是不同的国家,一个是不同的国家,用于每周的折线图。最初,我能够使用绘制线图,seaborn但它不是很方便,比如设置它的标签、图例、调色板等。我想知道是否有任何方法可以轻松地使用多个分类变量重塑这些数据并呈现折线图。在最初的尝试中,我尝试过,seaborn.relplot但调整它的参数并不容易,而且很难自定义结果图。谁能指出我用多个分类列重塑数据框并呈现清晰折线图的任何有效方法?有什么想法吗?

可重现的数据和我的尝试

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

url = 'https://gist.githubusercontent.com/adamFlyn/cb0553e009933574ac7ec3109ffb5140/raw/a277bc00dc08e526a7d5b7ead5425905f7206bfa/export.csv'
dff = pd.read_csv(url, parse_dates=['weekly'])
dff.drop('Unnamed: 0', axis=1, inplace=True)

df2_bf = dff.groupby(['destination', 'weekly'])['FCF_Beef'].sum().unstack()
df2_bf = df2_bf.fillna(0)
mm = df2_bf.T
mm.columns.name = None
mm = mm[~(mm.isna().sum(1)/mm.shape[1]).gt(0.9)].fillna(0)

#Total sum per column: 
mm.loc['Total',:]= mm.sum(axis=0)
mm1 = mm.T
mm1 = mm1.nlargest(6, columns=['Total'])
mm1.drop('Total', axis=1, inplace=True)
mm2 = mm1.T
mm2.reset_index(inplace=True)
mm2['weekly'] = pd.to_datetime(mm2['weekly'])

mm2['year'] = mm2['weekly'].dt.year
mm2['week'] = mm2['weekly'].dt.isocalendar().week
df = mm2.melt(id_vars=['weekly','week','year'], var_name='country')

df_ = df.groupby(['country', 'year', 'week'], as_index=False)['value'].sum()
sns.relplot(data=df_, x='week', y='value', hue='year', row='country', kind='line', height=6, aspect=2, facet_kws={'sharey': False, 'sharex': False}, sizes=(20, 10))

当前地块

这是我制作的当前情节之一seaborn.relplot

情节的结构对我来说还可以,但是在 中seaborn.replot,很难调整参数并且它和使用一样灵活matplotlib。另外,我意识到聚合我的数据的方式不是很有效。我认为可能有一种捷径可以使上述代码片段更高效,例如:

plt_data = []
for i in dff.loc[:, ['FCF_Beef','FCF_Beef']]:
    ...

但是这样做我遇到了几个问题来制作正确的情节。谁能指出我如何使这个简单有效,以便使用 matplotlib 制作预期的折线图?有谁知道这样做的更好方法?任何想法?谢谢

期望的输出

在我想要的情节中,首先我需要迭代国家列表,其中每个国家都有一个子情节,在每个子情节中,x 轴显示 52 周,y 轴显示weeklyExport每个国家不同年份的数量。这是我用seaborn.relplot.

请注意,我不喜欢 的输出seaborn.relplot,所以我想知道如何通过尝试使上述尝试更有效matplotlib。任何想法?

标签: pythonpandasmatplotlibplot

解决方案


  • 根据 OP 的要求,以下是绘制数据的迭代方式。
  • 以下示例每年绘制'destination'单个图中给定的图
  • 这类似于这个问题的答案
import pandas as pd
import matplotlib.pyplot as plt

# load the data
url = 'https://gist.githubusercontent.com/adamFlyn/cb0553e009933574ac7ec3109ffb5140/raw/a277bc00dc08e526a7d5b7ead5425905f7206bfa/export.csv'
df = pd.read_csv(url, parse_dates=['weekly'], usecols=range(1, 6))

# groupby destination and iterate through for plotting
for g, d in df.groupby(['destination']):

    # create the figure
    fig, ax = plt.subplots(figsize=(7, 4))
    
    # add lines for specific years
    for year in d.weekly.dt.year.unique():
        data = d[d.weekly.dt.year == year].copy()  # select the data from d, by year
        data['week'] = data.weekly.dt.isocalendar().week  # create a week column
        data.sort_values('weekly', inplace=True)
        display(data.head())  # display is for jupyter, if it causes an error, use pring
        data.plot(x='week', y='FCF_Beef', ax=ax, label=year)
    
    plt.show()
  • 单样本图

在此处输入图像描述

  • 如果我们查看其中一个数据框的尾部,data.weekly.dt.isocalendar().week将一年中的最后一天作为week 1,则将一条线画回到放置在第 1 周的最后一个数据点。
  • 根据这个已关闭的 pandas bug ,此功能依赖于datetime.datetime(2018, 12, 31).isocalendar()并且是模块的预期行为。datetime

在此处输入图像描述

  • 用 , 删除最后一行.iloc[:-1, :]是一种解决方法
  • 或者,替换data['week'] = data.weekly.dt.isocalendar().weekdata['week'] = data.weekly.dt.strftime('%W').astype('int')
data.iloc[:-1, :].plot(x='week', y='FCF_Beef', ax=ax, label=year)

在此处输入图像描述

更新了 OP 的所有代码

# load the data
url = 'https://gist.githubusercontent.com/adamFlyn/cb0553e009933574ac7ec3109ffb5140/raw/a277bc00dc08e526a7d5b7ead5425905f7206bfa/export.csv'
dff = pd.read_csv(url, parse_dates=['weekly'], usecols=range(1, 6))

df2_bf = dff.groupby(['destination', 'weekly'])['FCF_Beef'].sum().unstack()
df2_bf = df2_bf.fillna(0)
mm = df2_bf.T
mm.columns.name = None
mm = mm[~(mm.isna().sum(1)/mm.shape[1]).gt(0.9)].fillna(0)

#Total sum per column: 
mm.loc['Total',:]= mm.sum(axis=0)
mm1 = mm.T
mm1 = mm1.nlargest(6, columns=['Total'])
mm1.drop('Total', axis=1, inplace=True)
mm2 = mm1.T
mm2.reset_index(inplace=True)
mm2['weekly'] = pd.to_datetime(mm2['weekly'])

mm2['year'] = mm2['weekly'].dt.year
mm2['week'] = mm2['weekly'].dt.strftime('%W').astype('int')
df = mm2.melt(id_vars=['weekly','week','year'], var_name='country')

# groupby destination and iterate through for plotting
for g, d in df.groupby(['country']):

    # create the figure
    fig, ax = plt.subplots(figsize=(7, 4))
    
    # add lines for specific years
    for year in d.weekly.dt.year.unique():
        data = d[d.weekly.dt.year == year].copy()  # select the data from d, by year
        data.sort_values('weekly', inplace=True)
        display(data.head())  # display is for jupyter, if it causes an error, use pring
        data.plot(x='week', y='value', ax=ax, label=year, title=g)
    
    plt.show()

推荐阅读