Splitting a list of overlapping intervals into non-overlapping sub-intervals in a PySpark DataFrame

Problem description

I have a PySpark DataFrame with start_time and end_time columns that define a time interval for each row.

There is also a rate column, and I want to know whether overlapping rows assign different rate values to the same sub-interval (which, by definition, is covered more than once); if they do, I want to keep the last record as the ground truth.

Input:

# So this:
input_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-04 00:00:00', rate=10),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20          
              Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),  # OVERLAP: full overlap for (2,3) with (1,4)               
              Row(start_time='2018-01-03 00:00:00', end_time='2018-01-05 00:00:00', rate=20),  # OVERLAP: (3,5) and (1,4) and rate=10/20                          
              Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),  # NO OVERLAP: hole between (5,6)                                            
              Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)]  # NO OVERLAP

df = spark.createDataFrame(input_rows)
df.show()
>>> +-------------------+-------------------+----+
    |         start_time|           end_time|rate|
    +-------------------+-------------------+----+
    |2018-01-01 00:00:00|2018-01-04 00:00:00|  10|
    |2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
    |2018-01-03 00:00:00|2018-01-05 00:00:00|  20|
    |2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
    |2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
    +-------------------+-------------------+----+

# To give you: 
output_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10),
               Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),
               Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=20),
               Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20),
               Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),
               Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)
              ]
final_df = spark.createDataFrame(output_rows)
final_df.show()
>>> 
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  20|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+

Tags: python, apache-spark, pyspark, apache-spark-sql

Solution


You can compare each row's end_time with the next start_time, and if the latter is smaller, replace end_time with that next start_time.

from pyspark.sql import functions as F, Window

# end_time2 = earliest start_time among all following rows,
# i.e. the next interval's start (rows are ordered by start_time)
df2 = df.withColumn(
    'end_time2', 
    F.min('start_time').over(
        Window.orderBy('start_time')
              .rowsBetween(1, Window.unboundedFollowing)
    )
).select(
    'start_time',
    F.when(
        F.col('end_time2') < F.col('end_time'), 
        F.col('end_time2')
    ).otherwise(
        F.col('end_time')
    ).alias('end_time'),
    'rate'
)

df2.show()
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
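Since the rows are ordered by start_time, the minimum start_time over all following rows is simply the next row's start_time, so the same truncation can also be sketched with F.lead instead of F.min. This is a minimal alternative sketch, not part of the original answer; it assumes the timestamp strings sort chronologically, which holds for the 'YYYY-MM-DD HH:MM:SS' format used here.

from pyspark.sql import functions as F, Window

w = Window.orderBy('start_time')
df3 = df.withColumn(
    # next_start is the start_time of the immediately following row (null for the last row)
    'next_start', F.lead('start_time').over(w)
).select(
    'start_time',
    # keep the smaller of end_time and the next start_time; fall back to end_time on the last row
    F.least('end_time', F.coalesce('next_start', 'end_time')).alias('end_time'),
    'rate'
)
df3.show()  # same result as df2 above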
