首页 > 解决方案 > 使用 group by 检查 pandas 未来行中的条件

问题描述

以下是我的数据框的样子,Expected_Output也是我想要的列。

  Group  Signal  Value1  Value2  Expected_Output
0      1       0       3       1              NaN
1      1       1       4       2              NaN
2      1       0       7       4              NaN
3      1       0       8       9              1.0
4      1       0       5       3              NaN
5      2       1       3       6              NaN
6      2       1       1       2              1.0
7      2       0       3       4              1.0

对于给定Group的 if Signal == 1,然后我试图查看接下来的三行(而不是当前行)并检查 if Value1 < Value2Expected_Output如果该条件为真,则我在列中返回 1 。例如,如果Value < Value2条件由于多种原因而满足,因为它位于Signal == 1第 5 行和第 6 行(Group 2)中的下 3 行之内,那么我也将返回 1 in Expected_Output

我假设group by object, np.where, any,的正确组合shift可能是解决方案,但不能完全到达那里。

注意:- 亚历山大在评论中指出了一个冲突。理想情况下,由于前一行中的信号而设置的值将取代给定行中的当前行规则冲突。

标签: pythonpandasnumpy

解决方案


如果您要检查很多以前的行,多个班次很快就会变得混乱,但这里还不错:

s = df.groupby('Group').Signal

condition = ((s.shift(1).eq(1) | s.shift(2).eq(1) | s.shift(3).eq(1)) 
                & df.Value1.lt(df.Value2))

df.assign(out=np.where(condition, 1, np.nan))

   Group  Signal  Value1  Value2  out
0      1       0       3       1  NaN
1      1       1       4       2  NaN
2      1       0       7       4  NaN
3      1       0       8       9  1.0
4      1       0       5       3  NaN
5      2       1       3       6  NaN
6      2       1       1       2  1.0
7      2       0       3       4  1.0

如果您担心使用这么多班次的性能,我不会太担心,这里有一个 100 万行的示例:

In [401]: len(df)
Out[401]: 960000

In [402]: %%timeit
     ...: s = df.groupby('Group').Signal
     ...:
     ...: condition = ((s.shift(1).eq(1) | s.shift(2).eq(1) | s.shift(3).eq(1))
     ...:                 & df.Value1.lt(df.Value2))
     ...:
     ...: np.where(condition, 1, np.nan)
     ...:
     ...:
94.5 ms ± 524 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

@Alexander 确定了规则中的冲突,这是一个使用符合该要求的掩码的版本:

s = (df.Signal.mask(df.Signal.eq(0)).groupby(df.Group)
        .ffill(limit=3).mask(df.Signal.eq(1)).fillna(0))

现在您可以简单地将此列与您的其他条件一起使用:

np.where((s.eq(1) & df.Value1.lt(df.Value2)).astype(int), 1, np.nan)

array([nan, nan, nan,  1., nan, nan, nan,  1.])

推荐阅读