首页 > 解决方案 > 在 Java 中创建 UDF 以将一个数据框列映射到另一列

问题描述

我在 spark-shell 中编写了 Scala 代码,以将数据帧的一列映射到另一列。我现在正在尝试将其转换为 Java,但在使用我定义的 UDF 时遇到了困难。

我正在使用这个数据框:

+------+-----+-----+
|acctId|vehId|count|
+------+-----+-----+
|     1|  777|    3|
|     2|  777|    1|
|     1|  666|    1|
|     1|  999|    3|
|     1|  888|    2|
|     3|  777|    4|
|     2|  999|    1|
|     3|  888|    2|
|     2|  888|    3|
+------+-----+-----+

并将其转换为:

+------+----------------------------------------+
|acctId|vehIdToCount                            |
+------+----------------------------------------+
|1     |[777 -> 3, 666 -> 1, 999 -> 3, 888 -> 2]|
|3     |[777 -> 4, 888 -> 2]                    |
|2     |[777 -> 1, 999 -> 1, 888 -> 3]          |
+------+----------------------------------------+

我正在通过这些命令执行此操作。首先,我的 UDF 将行值列表从一列映射到第二列:

val listToMap = udf((input: Seq[Row]) => input.map(row => (row.getAs[Long](0), row.getAs[Long](1))).toMap)

我通过双重 groupBy/聚合来做到这一点:

val resultDF = testData.groupBy("acctId", "vehId")
     .agg(count("acctId").cast("long").as("count"))
     .groupBy("acctId")
     .agg(collect_list(struct("vehId", "count")) as ("vehIdToCount"))
     .withColumn("vehIdToCount", listToMap($"map"))

我的问题是尝试用 Java 编写 listToMap UDF。我对 Scala 和 Java 都很陌生,所以我可能只是遗漏了一些东西。

我希望我可以做一些简单的事情:

UserDefinedFunction listToMap = udf(
        (Seq<Dataset<Row>> input) -> input.map(r -> (r.get(“vehicleId”), r.get(“count”)));
);

但是,即使在相当广泛地浏览了文档之后,我也无法确定获取这些列中的每一列的有效方法。我也尝试过只做一个 SELECT 但这也不起作用。

任何帮助深表感谢。供您参考,这是我在 spark-shell 中生成测试数据的方式:

val testData = Seq(
    (1, 999),
    (1, 999),
    (2, 999),
    (1, 888),
    (2, 888),
    (3, 888),
    (2, 888),
    (2, 888),
    (1, 888),
    (1, 777),
    (1, 666),
    (3, 888),
    (1, 777),
    (3, 777),
    (2, 777),
    (3, 777),
    (3, 777),
    (1, 999),
    (3, 777),
    (1, 777)
).toDF("acctId", "vehId”)

标签: javaapache-sparkapache-spark-sqluser-defined-functions

解决方案


我无法帮助您编写 UDF,但我可以向您展示如何使用 Spark 的内置map_from_entries函数来避免它。UDF 应该始终是最后的手段,既要保持代码库简单,又因为 Spark 无法优化它们。下面的例子是在 Scala 中,但翻译起来应该很简单:

scala> val testData = Seq(
     |     (1, 999),
     |     (1, 999),
     |     (2, 999),
     |     (1, 888),
     |     (2, 888),
     |     (3, 888),
     |     (2, 888),
     |     (2, 888),
     |     (1, 888),
     |     (1, 777),
     |     (1, 666),
     |     (3, 888),
     |     (1, 777),
     |     (3, 777),
     |     (2, 777),
     |     (3, 777),
     |     (3, 777),
     |     (1, 999),
     |     (3, 777),
     |     (1, 777)
     | ).toDF("acctId", "vehId")
testData: org.apache.spark.sql.DataFrame = [acctId: int, vehId: int]

scala> 

scala> val withMap = testData.groupBy('acctId, 'vehId).
     | count.
     | select('acctId, struct('vehId, 'count).as("entries")).
     | groupBy('acctId).
     | agg(map_from_entries(collect_list('entries)).as("myMap"))
withMap: org.apache.spark.sql.DataFrame = [acctId: int, myMap: map<int,bigint>]

scala> 

scala> withMap.show(false)
+------+----------------------------------------+
|acctId|myMap                                   |
+------+----------------------------------------+
|1     |[777 -> 3, 666 -> 1, 999 -> 3, 888 -> 2]|
|3     |[777 -> 4, 888 -> 2]                    |
|2     |[777 -> 1, 999 -> 1, 888 -> 3]          |
+------+----------------------------------------+

推荐阅读