python - pandas groupby 过滤和采样
问题描述
我有一个数据框df
,我想按某些列进行分组,并从每个组中采样n_items 。
小于n_items 的组应该被忽略并且不被采样。
import numpy as np
import pandas as pd
n = int(1e3)
n_prod = 5
p = [0.05, 0.05, 0.3, 0.3, 0.3]
df = pd.DataFrame(
{
"Product": np.random.choice([f"Product_{i}" for i in range(n_prod)], n, p=p),
"Price": (np.random.random(n) * 50 + 10).round(2),
}
)
我找到了一个解决方案,但我认为可能存在一个更Pythonic的解决方案,不需要对数据框进行两次分组。
min_size = 150
df_sample = df.groupby("Product").filter(lambda x: len(x) >= min_size).groupby('Product').sample(min_size)
df_sample.shape
输出:
(450, 2)
编辑:请注意,单个groupby将从整个数据集中采样 150 行
df.groupby("Product").filter(lambda x: len(x) >= min_size).sample(min_size).shape
输出:
(150, 2)
解决方案
为了使用单个groupby
,我会将所有逻辑放在一个函数中并应用它:
def sample(g, size):
return g.sample(size) if g.shape[0] >= size else None
df_sample = df.groupby('Product', group_keys=False).apply(sample, size=min_size)
请注意,这仅比原始版本(双倍groupby
:2.77 毫秒)快一点(2.38 毫秒)。
另一种方法是用来value_counts()
获取组大小。正如@BlackMath 所指出的,value_counts()
比groupby(...).filter(...)
. 它也比groupby(...).size()
. 但下面的最终结果时间与您的原始解决方案相当(我的机器上为 2.69 毫秒与 2.77 毫秒):
s = df['Product'].value_counts()
df_sample = df.loc[df['Product'].isin(
s.index[s >= min_size])].groupby('Product').sample(min_size)