首页 > 解决方案 > 使用 Window 函数折叠 DataFrame

问题描述

我想根据 ID 列折叠数据框中的行,并使用窗口函数计算每个 ID 的记录数。这样做,我想避免按 ID 对窗口进行分区,因为这会导致大量的分区。

我有一个表格的数据框

+----+-----------+-----------+-----------+
| ID | timestamp | metadata1 | metadata2 |
+----+-----------+-----------+-----------+
|  1 | 09:00     | ABC       | apple     |
|  1 | 08:00     | NULL      | NULL      |
|  1 | 18:00     | XYZ       | apple     |
|  2 | 07:00     | NULL      | banana    |
|  5 | 23:00     | ABC       | cherry    |
+----+-----------+-----------+-----------+

我想只保留每个 ID 具有最新时间戳的记录,这样我就有

+----+-----------+-----------+-----------+-------+
| ID | timestamp | metadata1 | metadata2 | count |
+----+-----------+-----------+-----------+-------+
|  1 | 18:00     | XYZ       | apple     |     3 |
|  2 | 07:00     | NULL      | banana    |     1 |
|  5 | 23:00     | ABC       | cherry    |     1 |
+----+-----------+-----------+-----------+-------+

我努力了:

window = Window.orderBy( [asc('ID'), desc('timestamp')] )
window_count = Window.orderBy( [asc('ID'), desc('timestamp')] ).rowsBetween(-sys.maxsize,sys.maxsize)

columns_metadata = [metadata1, metadata2]

df = df.select(
              *(first(col_name, ignorenulls=True).over(window).alias(col_name) for col_name in columns_metadata),
              count(col('ID')).over(window_count).alias('count')
              )
df = df.withColumn("row_tmp", row_number().over(window)).filter(col('row_tmp') == 1).drop(col('row_tmp'))

这部分基于如何选择每组的第一行?

这没有使用 pyspark.sql.Window.partitionBy,这不会给出所需的输出。

标签: apache-sparkpysparkapache-spark-sql

解决方案


在我发布之后,我读到了你想要的没有按 ID 分区的内容。我只能想到这种方法。

您的数据框:

df = sqlContext.createDataFrame(
  [
     ('1', '09:00', 'ABC', 'apple')
    ,('1', '08:00', '', '')
    ,('1', '18:00', 'XYZ', 'apple')
    ,('2', '07:00', '', 'banana')
    ,('5', '23:00', 'ABC', 'cherry')
  ]
  ,['ID', 'timestamp', 'metadata1', 'metadata2']
)

我们可以在时间戳上按 ID 使用排名和分区:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

w1 = Window().partitionBy(df['ID']).orderBy(df['timestamp']).orderBy(F.desc('timestamp'))
w2 = Window().partitionBy(df['ID'])

df\
  .withColumn("rank", F.rank().over(w1))\
  .withColumn("count", F.count('ID').over(w2))\
  .filter(F.col('rank') == 1)\
  .select('ID', 'timestamp', 'metadata1', 'metadata2', 'count')\
  .show()

+---+---------+---------+---------+-----+
| ID|timestamp|metadata1|metadata2|count|
+---+---------+---------+---------+-----+
|  1|    18:00|      XYZ|    apple|    3|
|  2|    07:00|         |   banana|    1|
|  5|    23:00|      ABC|   cherry|    1|
+---+---------+---------+---------+-----+

推荐阅读