Converting from Spark function calls to SQL

Problem Description

I have a dataset with the following schema.

root
 |-- acct_id: long (nullable = true)
 |-- firm_bnkg_id: integer (nullable = true)
 |-- tagged: long (nullable = true)
 |-- transactions: array (nullable = false)
 |    |-- element: struct (containsNull = true)
 |    |    |-- mo_yr_buckt: string (nullable = false)
 |    |    |-- acct_id: long (nullable = false)
 |    |    |-- eff_dt: date (nullable = true)
 |    |    |-- extn_txn_cd: string (nullable = true)
 |    |    |-- mntr_txn_am: double (nullable = true)
 |    |    |-- cr_dr_in: string (nullable = true)
 |    |    |-- txn_desc_tx: string (nullable = true)
 |    |    |-- txn_auth_dt: date (nullable = false)
 |    |    |-- txn_auth_ts: string (nullable = false)
 |    |    |-- tagged: long (nullable = true)
 |    |    |-- firm_bnkg_id: integer (nullable = false)
 |    |    |-- txn_pst_sq_nb: string (nullable = false)
 |    |    |-- pst_dt: integer (nullable = false)
 |-- prty_ol_prfl_id: long (nullable = true)
 |-- prod_cd: string (nullable = true)
 |-- acct_type_cd: string (nullable = true)
 |-- acct_state_cd: string (nullable = true)

Now I want to convert the current code into a SQL statement. The current code looks like this:

val result = ds.select(col("*"), explode(col("transactions")).as("txn"))
  .where("IsValidUDF(txn) = TRUE")
  .groupBy("prty_ol_prfl_id")
  .agg(collect_list("txn").as("transactions"))

This produces the following schema:

root
 |-- acct_id: long (nullable = true)
 |-- firm_bnkg_id: integer (nullable = true)
 |-- tagged: long (nullable = true)
 |-- transactions: array (nullable = false)
 |    |-- element: struct (containsNull = true)
 |    |    |-- mo_yr_buckt: string (nullable = false)
 |    |    |-- acct_id: long (nullable = false)
 |    |    |-- eff_dt: date (nullable = true)
 |    |    |-- extn_txn_cd: string (nullable = true)
 |    |    |-- mntr_txn_am: double (nullable = true)
 |    |    |-- cr_dr_in: string (nullable = true)
 |    |    |-- txn_desc_tx: string (nullable = true)
 |    |    |-- txn_auth_dt: date (nullable = false)
 |    |    |-- txn_auth_ts: string (nullable = false)
 |    |    |-- tagged: long (nullable = true)
 |    |    |-- firm_bnkg_id: integer (nullable = false)
 |    |    |-- txn_pst_sq_nb: string (nullable = false)
 |    |    |-- pst_dt: integer (nullable = false)
 |-- prty_ol_prfl_id: long (nullable = true)
 |-- prod_cd: string (nullable = true)
 |-- acct_type_cd: string (nullable = true)
 |-- acct_state_cd: string (nullable = true)

IsValidUDF just checks the tagged column for certain values.
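The UDF itself is not shown in the question, but a minimal sketch of what it might look like, assuming a transaction is valid when its tagged field equals 1 (the field name comes from the schema above; the value 1 is a hypothetical placeholder):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// A struct column is passed to a Scala UDF as a Row; the nested
// "tagged" field is read with getAs. The == 1L check is only an
// illustrative assumption about what "certain values" means.
val is_valid_udf = udf((txn: Row) => txn.getAs[Long]("tagged") == 1L)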

Any help would be appreciated. Thanks.

Tags: sql, scala, apache-spark, apache-spark-sql

Solution


The translation of your code into a Spark SQL statement is:

val new_df = spark.sql("""
    WITH temp AS (
        SELECT *, explode(transactions) AS txn FROM df
    )
    SELECT first(id) AS id, collect_list(txn) AS TRANSACTIONS
    FROM temp
    WHERE IsValidUDF(txn) = TRUE
    GROUP BY id
""")

(Simply repeat a first(...) expression, as with first(id), for every column you want to keep in the resulting dataframe; see the sketch below.)
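For the schema in the question, the full statement might look like the following sketch (assuming the dataset is registered as a temp view named df, and grouping by prty_ol_prfl_id as in the original code):

val result = spark.sql("""
    WITH temp AS (
        SELECT *, explode(transactions) AS txn FROM df
    )
    SELECT prty_ol_prfl_id,
           first(acct_id)       AS acct_id,
           first(firm_bnkg_id)  AS firm_bnkg_id,
           first(tagged)        AS tagged,
           collect_list(txn)    AS transactions,
           first(prod_cd)       AS prod_cd,
           first(acct_type_cd)  AS acct_type_cd,
           first(acct_state_cd) AS acct_state_cd
    FROM temp
    WHERE IsValidUDF(txn) = TRUE
    GROUP BY prty_ol_prfl_id
""")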

Make sure your UDF is registered beforehand:

spark.udf.register("IsValidUDF", is_valid_udf)
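Alternatively, a plain Scala function can be registered under a SQL name in one step, without building a separate udf value first (the r > 50 body is just the toy check used below):

// one-step registration of a plain Scala function as a SQL UDF
spark.udf.register("IsValidUDF", (r: Int) => r > 50)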

Here is the complete code with a toy example:

import org.apache.spark.sql.functions.udf
import spark.implicits._   // needed for toDF outside the spark-shell

// Toy example: an id and a list of transaction amounts
val df = Seq((0, List(66, 1)), (1, List(98, 2)), (2, List(90))).toDF("id", "transactions")
df.createOrReplaceTempView("df")

// a transaction is valid when its amount is greater than 50
val is_valid_udf = udf((r: Int) => r > 50)

// register the UDF so it can be called from SQL
spark.udf.register("IsValidUDF", is_valid_udf)

// query
val new_df = spark.sql("""
    WITH temp AS (
        SELECT *, explode(transactions) AS txn FROM df
    )
    SELECT first(id) AS id, collect_list(txn) AS TRANSACTIONS
    FROM temp
    WHERE IsValidUDF(txn) = TRUE
    GROUP BY id
""")

Output:

+---+------------+
| id|TRANSACTIONS|
+---+------------+
|  1|        [98]|
|  2|        [90]|
|  0|        [66]|
+---+------------+

This is the original dataframe with transactions of 50 or less removed, i.e. only transactions for which IsValidUDF returns true are kept. Note that the WHERE clause is applied before the GROUP BY, so an id whose transactions all fail the check would disappear from the result entirely.

