Finding contiguous monthly enrollment periods in PySpark

Problem description

I'm trying to identify "contiguous" coverage periods, i.e. members enrolled for consecutive months, using a Spark DataFrame of health plan member IDs and enrollment months.

Below is a sample of the data I'm working with in PySpark (sc is a SparkSession).

import pandas as pd
import numpy as np

df = pd.DataFrame({'memid': ['123a', '123a', '123a', '123a', '123a', '123a',
                             '456b', '456b', '456b', '456b', '456b',
                             '789c', '789c', '789c', '789c', '789c', '789c'], 
                     'month_elig': ['2020-01-01', '2020-02-01', '2020-03-01', '2020-08-01', '2020-09-01', '2021-01-01',
                                    '2020-02-01', '2020-05-01', '2020-06-01', '2020-07-01', '2020-08-01',
                                    '2020-02-01', '2020-03-01', '2020-04-01', '2020-05-01', '2020-06-01', '2020-07-01']})
df['month_elig'] = pd.to_datetime(df['month_elig'])

# Whole months since the member's previous enrollment month (0 for each member's first row)
df['gap'] = (df.month_elig - df.groupby(['memid']).shift(1).month_elig)/np.timedelta64(1, 'M')
df['gap'] = np.where(df['gap'].isnull(), 0, df['gap'])
df['gap'] = np.round(df['gap'], 0)

scdf = sc.createDataFrame(df)

scdf.show()

#+-----+-------------------+---+
#|memid|         month_elig|gap|
#+-----+-------------------+---+
#| 123a|2020-01-01 00:00:00|0.0|
#| 123a|2020-02-01 00:00:00|1.0|
#| 123a|2020-03-01 00:00:00|1.0|
#| 123a|2020-08-01 00:00:00|5.0|
#| 123a|2020-09-01 00:00:00|1.0|
#| 123a|2021-01-01 00:00:00|4.0|
#| 456b|2020-02-01 00:00:00|0.0|
#| 456b|2020-05-01 00:00:00|3.0|
#| 456b|2020-06-01 00:00:00|1.0|
#| 456b|2020-07-01 00:00:00|1.0|
#| 456b|2020-08-01 00:00:00|1.0|
#| 789c|2020-02-01 00:00:00|0.0|
#| 789c|2020-03-01 00:00:00|1.0|
#| 789c|2020-04-01 00:00:00|1.0|
#| 789c|2020-05-01 00:00:00|1.0|
#| 789c|2020-06-01 00:00:00|1.0|
#| 789c|2020-07-01 00:00:00|1.0|
#+-----+-------------------+---+
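
As a side note, the gap column above is precomputed in pandas before the DataFrame is handed to Spark. If the data never passes through pandas, the same calculation can be done directly in Spark; a minimal sketch, assuming scdf starts with only memid and month_elig:

from pyspark.sql import functions as F
from pyspark.sql import Window

# Whole months since the member's previous enrollment month,
# defaulting to 0 for each member's first row (mirrors the pandas logic above)
w = Window.partitionBy('memid').orderBy('month_elig')
scdf = scdf.withColumn(
    'gap',
    F.coalesce(
        F.round(F.months_between('month_elig', F.lag('month_elig').over(w)), 0),
        F.lit(0.0)
    )
)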

If I could do this exercise in pandas, I would use the code below to create the unique_coverage_period field. But the solution needs to be in Spark because of the size of the data I'm working with, and from my research so far (for example), an iterative approach like this isn't really what Spark is designed for.

# Start a new coverage period whenever the gap to the previous month isn't exactly 1
a = 0
b = []
for i in df.gap.tolist():
    if i != 1:
        a += 1
    b.append(a)

df['unique_coverage_period'] = b

print(df)

#   memid month_elig  gap  unique_coverage_period
#0   123a 2020-01-01  0.0                       1
#1   123a 2020-02-01  1.0                       1
#2   123a 2020-03-01  1.0                       1
#3   123a 2020-08-01  5.0                       2
#4   123a 2020-09-01  1.0                       2
#5   123a 2021-01-01  4.0                       3
#6   456b 2020-02-01  0.0                       4
#7   456b 2020-05-01  3.0                       5
#8   456b 2020-06-01  1.0                       5
#9   456b 2020-07-01  1.0                       5
#10  456b 2020-08-01  1.0                       5
#11  789c 2020-02-01  0.0                       6
#12  789c 2020-03-01  1.0                       6
#13  789c 2020-04-01  1.0                       6
#14  789c 2020-05-01  1.0                       6
#15  789c 2020-06-01  1.0                       6
#16  789c 2020-07-01  1.0                       6

Tags: python, dataframe, apache-spark, pyspark, apache-spark-sql

Solution


I've since come up with another way to identify unique coverage periods. While I find the accepted answer posted by @mck clearer and more straightforward, the approach below appears to run faster on the real, much larger dataset of 84.6 million records I'm working with.
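
For reference, a window-based version in the spirit of that answer (my reconstruction, not @mck's exact code): every row where gap != 1 starts a new period, so a conditional running sum over the rows reproduces the pandas loop without iteration.

from pyspark.sql import functions as F
from pyspark.sql import Window

# Running count of period starts (gap != 1) over all rows, ordered the same
# way as the pandas loop; the single global window mirrors the global counter
# but forces all rows through one partition
w = Window.orderBy('memid', 'month_elig')
result = scdf.withColumn(
    'unique_coverage_period',
    F.sum(F.when(F.col('gap') != 1, 1).otherwise(0)).over(w))

Partitioning that window by memid instead would scale better, at the cost of the counter restarting for each member.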

from pyspark.sql import functions as F
from pyspark.sql import Window

# Create a new DataFrame that retains only the coverage-break months, then
# ranks each break month per member
w1 = Window().partitionBy('memid').orderBy(F.col('month_elig'))

scdf1 = scdf \
  .filter(F.col('gap') != 1) \
  .withColumn('rank', F.rank().over(w1)) \
  .select('memid', F.col('month_elig').alias('starter_month'), 'rank')

# Join the two DataFrames on memid and keep only the records where
# 'month_elig' is >= the 'starter_month'
scdf2 = scdf.join(scdf1, on='memid', how='inner') \
  .withColumn('starter', F.when(F.col('month_elig') == F.col('starter_month'), 1)
                          .otherwise(0)) \
  .filter(F.col('month_elig') >= F.col('starter_month'))

# If 'month_elig' == 'starter_month', keep that record; otherwise keep the
# latest 'starter_month' for each 'month_elig' record
w2 = Window().partitionBy(['memid', 'month_elig']).orderBy(F.col('starter').desc(), F.col('rank').desc())

scdf2 = scdf2 \
  .withColumn('rank', F.rank().over(w2)) \
  .filter(F.col('rank') == 1).drop('rank') \
  .withColumn('flag', F.concat(F.col('memid'), F.lit('_'), F.trunc(F.col('starter_month'), 'month'))) \
  .select('memid', 'month_elig', 'gap', 'flag')

scdf2.show()
+-----+-------------------+---+---------------+
|memid|         month_elig|gap|           flag|
+-----+-------------------+---+---------------+
| 789c|2020-02-01 00:00:00|0.0|789c_2020-02-01|
| 789c|2020-03-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-04-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-05-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-06-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-07-01 00:00:00|1.0|789c_2020-02-01|
| 123a|2020-01-01 00:00:00|0.0|123a_2020-01-01|
| 123a|2020-02-01 00:00:00|1.0|123a_2020-01-01|
| 123a|2020-03-01 00:00:00|1.0|123a_2020-01-01|
| 123a|2020-08-01 00:00:00|5.0|123a_2020-08-01|
| 123a|2020-09-01 00:00:00|1.0|123a_2020-08-01|
| 123a|2021-01-01 00:00:00|4.0|123a_2021-01-01|
| 456b|2020-02-01 00:00:00|0.0|456b_2020-02-01|
| 456b|2020-05-01 00:00:00|3.0|456b_2020-05-01|
| 456b|2020-06-01 00:00:00|1.0|456b_2020-05-01|
| 456b|2020-07-01 00:00:00|1.0|456b_2020-05-01|
| 456b|2020-08-01 00:00:00|1.0|456b_2020-05-01|
+-----+-------------------+---+---------------+
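
As a quick check, each distinct flag value identifies one contiguous coverage period, so period boundaries and lengths can be summarized directly from it (the aggregation and column names below are my own, not part of the original solution):

from pyspark.sql import functions as F

# One row per contiguous coverage period, with its start, end and length in months
scdf2.groupBy('memid', 'flag') \
  .agg(F.min('month_elig').alias('period_start'),
       F.max('month_elig').alias('period_end'),
       F.count('*').alias('months_enrolled')) \
  .show()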
