Running sum / cumulative sum with floor and ceiling in PySpark

Problem description

I am new to Spark, and I am trying to compute a running sum over a window, with a floor of 0 and a cap (ceiling) of 8.

A toy example is given below (note that the real data is closer to millions of rows):

import pyspark.sql.functions as F
from pyspark.sql import Window
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

pdf = pd.DataFrame({'aIds':    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
                    'day':     [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
                    'eCounts': [-3, 3, -6, 3, 3, 6, -3, -6, 3, 3, 3, -3]})
sdf = spark.createDataFrame(pdf)
sdf = sdf.orderBy(sdf.aIds, sdf.day)

This creates the table:

+----+---+-------+
|aIds|day|eCounts|
+----+---+-------+
|   1|  1|     -3|
|   1|  2|      3|
|   1|  3|     -6|
|   1|  4|      3|
|   2|  1|      3|
|   2|  2|      6|
|   2|  3|     -3|
|   2|  4|     -6|
|   3|  1|      3|
|   3|  2|      3|
|   3|  3|      3|
|   3|  4|     -3|
+----+---+-------+

Below is an example of the plain running sum result, together with the expected output runSumCap:

+----+---+-------+------+---------+
|aIds|day|eCounts|runSum|runSumCap|
+----+---+-------+------+---------+
|   1|  1|     -3|    -3|        0| <-- reset to 0
|   1|  2|      3|     0|        3|
|   1|  3|     -6|    -6|        0| <-- reset to 0
|   1|  4|      3|    -3|        3|
|   2|  1|      3|     3|        3|
|   2|  2|      6|     9|        8| <-- reset to 8
|   2|  3|     -3|     6|        5| 
|   2|  4|     -6|     0|        0| <-- reset to 0
|   3|  1|      3|     3|        3|
|   3|  2|      3|     6|        6|
|   3|  3|      3|     9|        8| <-- reset to 8
|   3|  4|     -3|     6|        5|
+----+---+-------+------+---------+

I know I can compute the (uncapped) running sum as:

partition = Window.partitionBy('aIds').orderBy('aIds','day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
sdf1 = sdf.withColumn('runSum',F.sum(sdf.eCounts).over(partition))
sdf1.orderBy('aIds','day').show()

To get the expected output, I tried looking at @pandas_udf to modify the sum:

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def runSumCap(counts):
    # the counts column is passed in as a pandas Series
    floor = 0
    cap = 8
    runSum = 0
    runSumList = []
    for count in counts.tolist():
        runSum = runSum + count
        if runSum > cap:
            runSum = cap
        elif runSum < floor:
            runSum = floor
        runSumList += [runSum]
    return pd.Series(runSumList)


partition = Window.partitionBy('aIds').orderBy('aIds','day').rowsBetween(Window.unboundedPreceding, Window.currentRow)
sdf1 = sdf.withColumn('runSum',runSumCap(sdf['eCounts']).over(partition))

However, this does not work, and it does not seem like the most efficient way to do it. How can I make this work? Is there a way to keep it parallel, or do I have to switch to pandas DataFrames?

EDIT: clarified which existing columns are used to sort the dataset, and added some more detail on what I am trying to achieve.

EDIT2: the answer provided by @DrChess almost produces the correct result, but for some reason the resulting series does not line up with the correct days:

+----+---+-------+------+
|aIds|day|eCounts|runSum|
+----+---+-------+------+
|   1|  1|     -3|     0|
|   1|  2|      3|     0|
|   1|  3|     -6|     3|
|   1|  4|      3|     3|
|   2|  1|      3|     3|
|   2|  2|      6|     8|
|   2|  3|     -3|     0|
|   2|  4|     -6|     5|
|   3|  1|      3|     6|
|   3|  2|      3|     3|
|   3|  3|      3|     8|
|   3|  4|     -3|     5|
+----+---+-------+------+

Tags: apache-spark, pyspark, pyspark-sql, pyspark-dataframes

Solution


Unfortunately, window functions with a pandas_udf of type GROUPED_AGG do not work with bounded window frames (i.e. .rowsBetween(Window.unboundedPreceding, Window.currentRow)). They currently only work with unbounded windows, i.e. .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing). In addition, the input is a pandas.Series but the output must be a single scalar of the declared type, so you cannot produce one value per row this way.
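
For illustration, here is a minimal sketch of what a GROUPED_AGG pandas_udf can do over a window: reduce the whole series of each (unbounded) frame to a single scalar, here the mean of eCounts per aIds. This assumes Spark 2.4+ with PyArrow installed, and the meanCounts name is just for illustration:

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def meanCounts(counts):
    # GROUPED_AGG must collapse the whole pandas Series into one scalar
    return counts.mean()

unbounded = Window.partitionBy('aIds') \
                  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
sdf.withColumn('meanCounts', meanCounts(sdf['eCounts']).over(unbounded)).show()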

Instead, you can use a GROUPED_MAP pandas_udf, which works with df.groupBy().apply(). Here is some code:

@pandas_udf('aIds integer, day integer, eCounts integer, runSum integer', PandasUDFType.GROUPED_MAP)
def runSumCap(pdf):
    def _apply_on_series(counts):
        floor = 0
        cap = 8
        runSum = 0
        runSumList = []
        for count in counts.tolist():
            runSum = runSum + count
            if runSum > cap:
                runSum = cap
            elif runSum < floor:
                runSum = floor
            runSumList += [runSum]
        return pd.Series(runSumList)
    pdf.sort_values(by=['day'], inplace=True)
    # Assign positionally via .values: assigning the raw Series would align on the
    # (now shuffled) index left behind by sort_values, which is what caused the
    # mismatched days shown in EDIT2 above.
    pdf['runSum'] = _apply_on_series(pdf['eCounts']).values
    return pdf


sdf1 = sdf.groupBy('aIds').apply(runSumCap)
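
As a quick sanity check, you can order the result by id and day and compare its runSum column with the expected runSumCap values shown earlier, for example:

sdf1.orderBy('aIds', 'day').show()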
