How to apply a lambda function in a pandas groupby and add a size column at the same time?

Question

I have a pandas DataFrame that looks like this:

import numpy as np
import pandas as pd

df = pd.DataFrame({"A": [1, 1, 1, 2, 2],
                   "B": ["apple", "apple", "banana", "pineapple", "pineapple"],
                   "C": [[6, 5, 2], [2, 10, 2], [5, 37, 1], [4, 19, 2], [1, 5, 1]]})

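I want to run a lambda function on the list column C that computes the element-wise mean of the lists in each (A, B) group. I know I can do this in either of the following ways: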

1:

df = df.groupby(['A', 'B'])['C'].apply(lambda s: np.array(list(s)).mean(axis=0)).reset_index()

2:

s = df.set_index(['A', 'B'])
# Expand the lists into columns, then average within each (A, B) group
out = pd.DataFrame(list(s['C']), s.index).groupby(level=[0, 1]).mean()
out.drop(columns=out.columns).assign(C=out.values.tolist()).reset_index()

On top of that, I want to add a column containing the size of each group. I know I can get the group sizes with:

df = df[['A', 'B']].groupby(['A', 'B']).size()

but I can't figure out how to do both in one go. The resulting DataFrame should look like this:

A        B            C              count
1        apple        [4, 7.5, 2]    2
1        banana       [5, 37, 1]     1
2        pineapple    [2.5, 12, 1.5] 2

How can I do this as quickly as possible? My real DataFrame is very large, so every operation needs to be as time-efficient as possible.

Tags: python, pandas, dataframe, pandas-groupby

Answer


Option 1: groupby agg

import numpy as np
import pandas as pd

df = pd.DataFrame({
    "A": [1, 1, 1, 2, 2],
    "B": ["apple", "apple", "banana", "pineapple", "pineapple"],
    "C": [[6, 5, 2], [2, 10, 2], [5, 37, 1], [4, 19, 2], [1, 5, 1]]
})

df = df.groupby(['A', 'B'])['C'].agg(
    [lambda s: np.array(list(s)).mean(axis=0).tolist(), 'size']
).reset_index().rename(columns={'<lambda_0>': 'C', 'size': 'count'})

print(df)


   A          B                 C  count
0  1      apple   [4.0, 7.5, 2.0]      2
1  1     banana  [5.0, 37.0, 1.0]      1
2  2  pineapple  [2.5, 12.0, 1.5]      2
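
On pandas 0.25 or later, named aggregation should give the same frame without the rename step, because the keyword names become the output column names. A minimal sketch, assuming the same df as in the question:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "A": [1, 1, 1, 2, 2],
    "B": ["apple", "apple", "banana", "pineapple", "pineapple"],
    "C": [[6, 5, 2], [2, 10, 2], [5, 37, 1], [4, 19, 2], [1, 5, 1]]
})

# Named aggregation: keyword names (C, count) become the output columns,
# so there is no '<lambda_0>' / 'size' column to rename afterwards.
out = df.groupby(['A', 'B'])['C'].agg(
    C=lambda s: np.array(list(s)).mean(axis=0).tolist(),
    count='size',
).reset_index()

print(out)
```

The work done per group is the same as in Option 1, so the speed should be similar; the gain is only in readability.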

Option 2: pd.concat

g = df.groupby(['A', 'B'])['C']

df = pd.concat(
    (
        g.apply(lambda s: np.array(list(s)).mean(axis=0)),
        g.size()
    ), axis=1,
    keys=['C', 'count']
).reset_index()

df

   A          B                 C  count
0  1      apple   [4.0, 7.5, 2.0]      2
1  1     banana  [5.0, 37.0, 1.0]      1
2  2  pineapple  [2.5, 12.0, 1.5]      2
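
A note on the column name: both options use size, which counts every row in a group, including rows where C is missing, while count would only count non-missing values. With the question's data the two agree. The small frame below is hypothetical, made up only to show where they differ:

```python
import pandas as pd

# Hypothetical toy data: the second row of C is missing (None)
df = pd.DataFrame({"A": [1, 1, 2],
                   "B": ["apple", "apple", "pineapple"],
                   "C": [[6, 5, 2], None, [4, 19, 2]]})

g = df.groupby(['A', 'B'])['C']
sizes = g.size()    # rows per group, missing values included
counts = g.count()  # non-missing values per group only
print(pd.concat([sizes, counts], axis=1, keys=['size', 'count']))
```

If missing lists can occur in the real data, it is worth deciding which of the two the "count" column is meant to report.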

Some timings via perfplot:

[Figure: perfplot timing comparison of the current answers]


import numpy as np
import pandas as pd
import perfplot

np.random.seed(5)


def gen_data(n):
    df = pd.DataFrame({"A": [1, 1, 1, 2, 2],
                       "B": ["apple", "apple", "banana", "pineapple",
                             "pineapple"],
                       "C": [[6, 5, 2], [2, 10, 2], [5, 37, 1], [4, 19, 2],
                             [1, 5, 1]]})

    return pd.concat([df.assign(B=df['B'] + str(i)) for i in range(n)],
                     ignore_index=True)


def chain_assign(df):
    return (df.groupby(['A', 'B'])['C']
            .apply(lambda s: np.array(list(s)).mean(axis=0))
            .reset_index()
            .assign(count=df.groupby(['A', 'B'])['C'].size().values))


def to_frame_value_counts(df):
    return (df.groupby(['A', 'B'])['C']
            .apply(lambda s: np.array(list(s)).mean(axis=0))
            .to_frame('C')
            .assign(count=df[['A', 'B']].value_counts())
            .reset_index())


def pd_concat(df):
    g = df.groupby(['A', 'B'])['C']

    return pd.concat(
        (
            g.apply(lambda s: np.array(list(s)).mean(axis=0)),
            g.size()
        ), axis=1,
        keys=['C', 'count']
    ).reset_index()


def groupby_agg(df):
    return df.groupby(['A', 'B'])['C'].agg(
        [lambda s: np.array(list(s)).mean(axis=0).tolist(), 'size']
    ).reset_index().rename(columns={'<lambda_0>': 'C', 'size': 'count'})


def agg_len_apply(df):
    df = df.groupby(['A', 'B'], as_index=False).agg({'C': list})
    df['count'] = df['C'].str.len()
    df['C'] = df['C'].apply(lambda c: np.array(c).mean(axis=0))
    return df


def agg_map(df):
    return df.groupby('B')['C'].agg(
        mean=(lambda x: x.map(np.array).mean().tolist()),
        size=(lambda x: x.size))


if __name__ == '__main__':
    out = perfplot.bench(
        setup=gen_data,
        kernels=[
            chain_assign,
            pd_concat,
            groupby_agg,
            agg_len_apply,
            to_frame_value_counts,
            agg_map
        ],
        # Labels must be listed in the same order as kernels
        labels=[
            'chain_assign (Anurag Dabas)',
            'pd_concat (Henry Ecker)',
            'groupby_agg (Henry Ecker)',
            'agg_len_apply (SomeDude)',
            'to_frame_value_counts (Anurag Dabas)',
            'agg_map (rhug123)'
        ],
        n_range=[2 ** k for k in range(20)],
        equality_check=None
    )
    out.save('perfplot_results.png', transparent=False)
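
If every list in C has the same length (an assumption; the question does not say so), the per-group Python lambda can be avoided altogether: stack C into a 2-D float array once and let groupby compute column-wise means in vectorized code. The function name groupby_stack_mean is hypothetical, and this kernel was not part of the timings above; it could be added to the perfplot kernels list to check whether it is actually faster on real data. Ragged lists make np.array(..., dtype=float) raise a ValueError.

```python
import numpy as np
import pandas as pd


def groupby_stack_mean(df):
    # Assumes all lists in C have equal length; ragged lists raise ValueError here
    arr = np.array(df['C'].tolist(), dtype=float)
    # One column per list position, aligned with df's rows
    wide = pd.DataFrame(arr, index=df.index)
    g = wide.groupby([df['A'], df['B']])
    means = g.mean()  # column-wise mean per (A, B) group, no Python lambda
    out = means.index.to_frame(index=False)
    out['C'] = means.to_numpy().tolist()  # back to one list per group
    out['count'] = g.size().to_numpy()    # same sorted group order as means
    return out


df = pd.DataFrame({"A": [1, 1, 1, 2, 2],
                   "B": ["apple", "apple", "banana", "pineapple", "pineapple"],
                   "C": [[6, 5, 2], [2, 10, 2], [5, 37, 1], [4, 19, 2], [1, 5, 1]]})
print(groupby_stack_mean(df))
```

The trade-off is memory: the stacked array holds every list element as a float, which for a very large frame is an extra copy of the whole C column.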
