python - Removing rows following (and including) the first occurrence of a particular column value
Problem description
I have a very large pd.DataFrame containing millions of records, where PID and Ses_ID are both index columns and Var_3 indicates the occurrence of some event.
PID | Ses_ID | Var_1 | Var_2 | Var_3 |
---|---|---|---|---|
001 | 001 | 0.7 | 0.5 | 0 |
001 | 002 | 0.8 | 0.4 | 1 |
001 | 003 | 0.9 | 0.3 | 0 |
002 | 004 | 0.8 | 0.2 | 0 |
002 | 005 | 0.7 | 0.1 | 0 |
002 | 006 | 0.8 | 0.7 | 1 |
002 | 007 | 0.2 | 0.2 | 0 |
002 | 008 | 0.8 | 0.2 | 1 |
I want to remove, from each person's records (persons are indexed by PID), every session from the first occurrence of Var_3==1 onward, including the session where it occurs. The provided example would therefore become:
PID | Ses_ID | Var_1 | Var_2 | Var_3 |
---|---|---|---|---|
001 | 001 | 0.7 | 0.5 | 0 |
002 | 004 | 0.8 | 0.2 | 0 |
002 | 005 | 0.7 | 0.1 | 0 |
I could iteratively add the relevant sessions and the corresponding PID to a new dataframe, but that would be extremely time-consuming given the size of the current dataframe. What would be an efficient way of achieving this? Many thanks!
Updated situation: I have found that many rows share the same Ses_ID. How do I remove sessions following (and including) the first occurrence of a particular column value? In the example below, both rows for Ses_ID==005 would be removed because the Var_3==1 event occurred in that session.
PID | Ses_ID | Var_1 | Var_2 | Var_3 |
---|---|---|---|---|
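001 | 001 | 0.7 | 0.5 | 0 |
001 | 002 | 0.8 | 0.4 | 1 |
001 | 003 | 0.9 | 0.3 | 0 |
002 | 009 | 0.1 | 0.3 | 0 |
002 | 004 | 0.8 | 0.1 | 0 |
002 | 004 | 0.8 | 0.2 | 0 |
002 | 005 | 0.7 | 0.1 | 0 |
002 | 005 | 0.8 | 0.7 | 1 |
002 | 006 | 0.2 | 0.2 | 0 |
002 | 007 | 0.8 | 0.2 | 1 |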
should be transformed to:
PID | Ses_ID | Var_1 | Var_2 | Var_3 |
---|---|---|---|---|
001 | 001 | 0.7 | 0.5 | 0 |
002 | 009 | 0.1 | 0.3 | 0 |
002 | 004 | 0.8 | 0.1 | 0 |
002 | 004 | 0.8 | 0.2 | 0 |
Solution
You can try boolean indexing:
# assuming PID, Ses_ID are indices:
mask = df.groupby(level=0)["Var_3"].cumsum().eq(0)
print(df[mask])
Prints:
            Var_1  Var_2  Var_3
PID Ses_ID
1   1         0.7    0.5      0
2   4         0.8    0.2      0
    5         0.7    0.1      0
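To make the answer easier to reproduce, here is a minimal self-contained sketch that rebuilds the first example from the question and applies the same cumulative-sum mask. The integer index values and the name `events_so_far` are illustrative choices, not part of the original answer. It assumes that Var_3 only takes the values 0 and 1 and that each person's rows are already in session order.

```python
import pandas as pd

# Sample data from the question, with PID and Ses_ID as a MultiIndex.
# Assumption: rows are already ordered by session within each PID.
df = pd.DataFrame(
    {
        "PID": [1, 1, 1, 2, 2, 2, 2, 2],
        "Ses_ID": [1, 2, 3, 4, 5, 6, 7, 8],
        "Var_1": [0.7, 0.8, 0.9, 0.8, 0.7, 0.8, 0.2, 0.8],
        "Var_2": [0.5, 0.4, 0.3, 0.2, 0.1, 0.7, 0.2, 0.2],
        "Var_3": [0, 1, 0, 0, 0, 1, 0, 1],
    }
).set_index(["PID", "Ses_ID"])

# Running count of events within each PID. Because Var_3 is 0 or 1,
# the count is 0 only on rows that come before the first event.
events_so_far = df.groupby(level="PID")["Var_3"].cumsum()

# Keep only the rows where no event has happened yet for that person.
result = df[events_so_far.eq(0)]
print(result)
```

The mask is computed in one vectorized pass over the column, which avoids building a new dataframe row by row.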
Edit:
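g = df.groupby(level=0)
# Within each session, sort Var_3 in descending order so that an event
# anywhere in the session lands on the session's first row.
# Note: this overwrites the Var_3 column in df.
df["Var_3"] = g["Var_3"].transform(
    lambda x: x.groupby(level=1).transform(sorted, reverse=True)
)
mask = g["Var_3"].cumsum().eq(0)
print(df[mask])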
Prints:
            Var_1  Var_2  Var_3
PID Ses_ID
1   1         0.7    0.5      0
2   4         0.8    0.2      0
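As a possible alternative for the duplicated-session case, the sketch below flags whole sessions with a per-session maximum rather than sorting, which leaves the original Var_3 column untouched. This is a different technique from the answer's code, not the answerer's method. The names `session_has_event` and `events_so_far` are illustrative. It assumes Var_3 is 0 or 1, that rows of the same session are adjacent, and that sessions appear in chronological order within each PID.

```python
import pandas as pd

# Updated sample data from the question; some Ses_ID values repeat.
# Assumption: rows of one session are adjacent and sessions are in time order.
df = pd.DataFrame(
    {
        "PID": [1, 1, 1, 2, 2, 2, 2, 2, 2, 2],
        "Ses_ID": [1, 2, 3, 9, 4, 4, 5, 5, 6, 7],
        "Var_1": [0.7, 0.8, 0.9, 0.1, 0.8, 0.8, 0.7, 0.8, 0.2, 0.8],
        "Var_2": [0.5, 0.4, 0.3, 0.3, 0.1, 0.2, 0.1, 0.7, 0.2, 0.2],
        "Var_3": [0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
    }
).set_index(["PID", "Ses_ID"])

# 1 on every row of a session that contains at least one event, else 0.
session_has_event = df.groupby(level=["PID", "Ses_ID"])["Var_3"].transform("max")

# Running count per PID; it stays 0 until the first session with an event.
events_so_far = session_has_event.groupby(level="PID").cumsum()

# Use a plain boolean array so the duplicated index labels need no alignment.
result = df[events_so_far.eq(0).to_numpy()]
print(result)
```

On this sample the result keeps session 001 for PID 1 and sessions 009, 004, 004 for PID 2, which matches the transformed table in the question.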