How to apply window functions over newly computed columns (in-memory transformation) with Scala in Spark

Problem description

I have a dataframe that I want to transform into the output below, where each row's start_duration and end_duration are derived from the previous row's start_duration and end_duration. Please let me know how to achieve this in Spark with Scala.

Below are the formulas for computing start_duration and end_duration:

start_duration = max(previous end_duration + 1, current date)
end_duration = min(prescription end_date, start_duration + duration - 1)
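
For illustration (not part of the original question), here is the recurrence applied by hand, with plain java.time.LocalDate, to the second dispensation of the sample data below (date 2015-07-15, duration 30, previous end_duration 2015-07-09, prescription end_date 2015-12-01); the result matches the expected output further down.

import java.time.LocalDate

val prevEndDuration     = LocalDate.parse("2015-07-09")  // end_duration of the previous row
val currentDate         = LocalDate.parse("2015-07-15")  // this row's dispensation date
val prescriptionEndDate = LocalDate.parse("2015-12-01")
val duration            = 30

// start_duration = max(previous end_duration + 1, current date)
val candidate     = prevEndDuration.plusDays(1)
val startDuration = if (candidate.isAfter(currentDate)) candidate else currentDate          // 2015-07-15

// end_duration = min(prescription end_date, start_duration + duration - 1)
val capped      = startDuration.plusDays(duration - 1)
val endDuration = if (capped.isBefore(prescriptionEndDate)) capped else prescriptionEndDate // 2015-08-13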

Below is my input dataframe:

+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+
|prescription_uid|patient_uid|ndc      |label      |dispensation_uid|date      |duration|start_date|end_date  |
+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+
|0               |0          |16714-128|sinvastatin|0               |2015-06-10|30      |2015-06-01|2015-12-01|
|0               |0          |16714-128|sinvastatin|1               |2015-07-15|30      |2015-06-01|2015-12-01|
|0               |0          |16714-128|sinvastatin|2               |2015-08-01|30      |2015-06-01|2015-12-01|
|0               |0          |16714-128|sinvastatin|3               |2015-10-01|30      |2015-06-01|2015-12-01|
+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+

Expected output dataframe:

+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+--------------------+------------------+--------------+------------+
|prescription_uid|patient_uid|ndc      |label      |dispensation_uid|date      |duration|start_date|end_date  |first_start_duration|first_end_duration|start_duration|end_duration|
+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+--------------------+------------------+--------------+------------+
|0               |0          |16714-128|sinvastatin|0               |2015-06-10|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-06-10    |2015-07-09  |
|0               |0          |16714-128|sinvastatin|1               |2015-07-15|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-07-15    |2015-08-13  |
|0               |0          |16714-128|sinvastatin|2               |2015-08-01|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-08-14    |2015-09-13  |
|0               |0          |16714-128|sinvastatin|3               |2015-10-01|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-10-01    |2015-10-30  |
+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+--------------------+------------------+--------------+------------+
    
Code tried:

val windowByPatient = Window.partitionBy($"patient_uid").orderBy($"date")
val windowByPatientBeforeCurrentRow = windowByPatient.rowsBetween(Window.unboundedPreceding, -1)

joinedPrDF = joinedPrDF
  .withColumn("first_start_duration", firstStartDuration(first($"date").over(windowByPatient), $"start_date"))
  .withColumn("first_end_duration", firstEndDuration($"first_start_duration", $"end_date", $"duration"))
  .withColumn("start_duration", when(count("*").over(windowByPatient) === 1, $"first_start_duration")
    .otherwise(startDurationCalc($"first_start_duration", $"date", $"start_date", coalesce(sum($"duration").over(windowByPatientBeforeCurrentRow), lit("0")))))
  .withColumn("end_duration", when(count("*").over(windowByPatient) === 1, $"first_end_duration")
    .otherwise(endDurationCalc($"end_date", $"start_duration", $"duration")))

UDF:

val startDurationCalc = udf((firstStrtDur: java.sql.Date, currentDsDate: java.sql.Date,
                             prsStartDate: java.sql.Date, duration: Int) => {
  println("===" + firstStrtDur + "===" + currentDsDate + "===" + prsStartDate + "===" + duration)

  val startDate = java.sql.Date.valueOf(firstStrtDur.toLocalDate.plusDays(duration))
  if (startDate.after(currentDsDate)) {
    startDate
  } else {
    currentDsDate
  }
}: java.sql.Date)

val endDurationCalc = udf((prsEndDate: java.sql.Date, startDuration: java.sql.Date, duration: Int) => {
  println("endDateCalcContRow ===" + prsEndDate + "===" + startDuration + "===" + duration)

  val currEndDate = java.sql.Date.valueOf(startDuration.toLocalDate.plusDays(duration - 1))
  if (currEndDate.before(prsEndDate)) {
    currEndDate
  } else {
    prsEndDate
  }
}: java.sql.Date)

Tags: scala, apache-spark, apache-spark-sql

Solution


You should not expect a window function to compute over data that is not present in the dataframe but is only produced during execution (what you call "in-memory rows"). That is not possible.
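
As a side note not covered by this answer: if you truly need the exact recurrence from the question (each start_duration depending on the end_duration computed for the previous row), one option is to walk each patient's rows sequentially with a typed groupByKey/flatMapGroups instead of window functions. A minimal sketch, assuming Spark 3.x (for java.time.LocalDate support in Dataset encoders) and a hypothetical Dispensation case class, neither of which appears in the original question:

import java.time.LocalDate
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical case classes mirroring the relevant input columns,
// with dates already parsed as java.time.LocalDate.
case class Dispensation(patientUid: Long, date: LocalDate, duration: Int, endDate: LocalDate)
case class WithDurations(dispensation: Dispensation, startDuration: LocalDate, endDuration: LocalDate)

def exactDurations(ds: Dataset[Dispensation])(implicit spark: SparkSession): Dataset[WithDurations] = {
  import spark.implicits._
  ds.groupByKey(_.patientUid).flatMapGroups { (_, rows) =>
    // Walk this patient's dispensations in date order, carrying the previous end_duration.
    rows.toSeq.sortBy(_.date.toEpochDay).foldLeft(Vector.empty[WithDurations]) { (acc, d) =>
      // start_duration = max(previous end_duration + 1, current date)
      val candidate = acc.lastOption.map(_.endDuration.plusDays(1)).getOrElse(d.date)
      val start     = if (candidate.isAfter(d.date)) candidate else d.date
      // end_duration = min(prescription end_date, start_duration + duration - 1)
      val capped = start.plusDays(d.duration - 1)
      val end    = if (capped.isBefore(d.endDate)) capped else d.endDate
      acc :+ WithDurations(d, start, end)
    }
  }
}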

Alternatively, you can stay with window functions and try a different approach: compute each row's start_duration from the patient's first date, based on the accumulated durations, taking possible gaps between dispensations into account:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val windowByPatient = Window.partitionBy("patient_uid").orderBy("date")
val windowByPatientBeforeCurrentRow = windowByPatient.rowsBetween(Window.unboundedPreceding, -1)

// `data` is the input dataframe shown in the question
data
  // gap in days between the current dispensation and the previous one
  .withColumn("previous_date", lag("date", 1).over(windowByPatient))
  .withColumn("diff_from_prev", datediff(col("date"), coalesce(col("previous_date"), col("date"))))
  // keep the gap when it is at least as long as the previous duration, otherwise use the current duration
  .withColumn("diff_with_duration", when(col("diff_from_prev") >= lag("duration", 1, 0).over(windowByPatient), col("diff_from_prev")).otherwise(col("duration")))
  .withColumn("first_date_by_patient", first("date").over(windowByPatient))
  // days from the patient's first date, accumulating the gaps of all preceding rows
  .withColumn("duration_from_first_with_gaps", col("diff_with_duration") + coalesce(sum("diff_from_prev").over(windowByPatientBeforeCurrentRow), lit(0)))
  .withColumn("start_duration", expr("date_add(first_date_by_patient, duration_from_first_with_gaps)"))
  .withColumn("end_duration", expr("date_add(start_duration, duration - 1)"))
  // keep only the original columns plus the two computed ones
  .select((data.columns ++ Seq("start_duration", "end_duration")).map(col): _*)
  .show()

date_add is wrapped in expr because it takes an Int as its second argument, but in a SQL expression it can be used with a column.
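
As an aside (an assumption about newer Spark versions, not part of the original answer): Spark 3.0 and later also provide a date_add overload that accepts a Column for the number of days, so the expr wrapper can be dropped there. A small sketch, reusing the helper columns computed above:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, date_add}

// Same two columns as above, but with the Column-based date_add overload
// (available in Spark 3.0+; verify against your version).
def addDurationColumns(df: DataFrame): DataFrame =
  df.withColumn("start_duration", date_add(col("first_date_by_patient"), col("duration_from_first_with_gaps")))
    .withColumn("end_duration", date_add(col("start_duration"), col("duration") - 1))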

