首页 > 解决方案 > pandas groupby 过滤和采样

问题描述

我有一个数据框df,我想按某些列进行分组,并每个组中采样n_items 。 小于n_items 的组应该被忽略并且不被采样。

import numpy as np
import pandas as pd
n = int(1e3)
n_prod = 5
p = [0.05, 0.05, 0.3, 0.3, 0.3]
df = pd.DataFrame(
    {
        "Product": np.random.choice([f"Product_{i}" for i in range(n_prod)], n, p=p),
         "Price": (np.random.random(n) * 50 + 10).round(2),
     }
 )

我找到了一个解决方案,但我认为可能存在一个更Pythonic的解决方案,不需要对数据框进行两次分组

min_size = 150
df_sample = df.groupby("Product").filter(lambda x: len(x) >= min_size).groupby('Product').sample(min_size)
df_sample.shape

输出:

(450, 2)

编辑:请注意,单个groupby将从整个数据集中采样 150 行

df.groupby("Product").filter(lambda x: len(x) >= min_size).sample(min_size).shape

输出:

(150, 2)

标签: pythonpandaspandas-groupby

解决方案


为了使用单个groupby,我会将所有逻辑放在一个函数中并应用它:

def sample(g, size):
    return g.sample(size) if g.shape[0] >= size else None

df_sample = df.groupby('Product', group_keys=False).apply(sample, size=min_size)

请注意,这仅比原始版本(双倍groupby:2.77 毫秒)快一点(2.38 毫秒)。

另一种方法是用来value_counts()获取组大小。正如@BlackMath 所指出的,value_counts()groupby(...).filter(...). 它也比groupby(...).size(). 但下面的最终结果时间与您的原始解决方案相当(我的机器上为 2.69 毫秒与 2.77 毫秒):

s = df['Product'].value_counts()
df_sample = df.loc[df['Product'].isin(
    s.index[s >= min_size])].groupby('Product').sample(min_size)

推荐阅读