首页 > 解决方案 > 仅将每行的非空列收集到数组中

问题描述

困难在于我试图尽可能地避免 UDF。

我有一个数据集“wordsDS”,其中包含许多空值:

+------+------+------+------+
|word_0|word_1|word_2|word_3|
+------+------+------+------+
|     a|     b|  null|     d|
|  null|     f|     m|  null|
|  null|  null|     d|  null|
+--------------+------+-----|

我需要将每一行的所有列收集到 array。我事先不知道列数,所以我使用 columns() 方法。

groupedQueries = wordsDS.withColumn("collected",
      functions.array(Arrays.stream(wordsDS.columns())
               .map(functions::col).toArray(Column[]::new)));;

但是这种方法会产生空元素

+--------------------+
|           collected|
+--------------------+
|           [a, b,,d]|
|          [, f, m,,]|
|            [,, d,,]|
+--------------------+

相反,我需要以下结果:

+--------------------+
|           collected|
+--------------------+
|           [a, b, d]|
|              [f, m]|
|                 [d]|
+--------------------+

所以基本上,我需要收集每一行的所有列,以符合以下要求:

  1. 结果数组不包含空元素。
  2. 不知道前面的列数。

我也考虑过过滤数据集的“收集”列以获取空值的方法,但除了 UDF 之外无法提供任何其他内容。我试图避免 UDF 以免影响性能,如果有人可以建议一种方法来过滤数据集的“收集”列以获取尽可能少的开销的空值,那将非常有帮助。

标签: apache-sparkapache-spark-sql

解决方案


您可以使用array("*")将所有元素放入 1 个数组中,然后使用array_except(需要 Spark 2.4+)过滤掉空值:

df
  .select(array_except(array("*"),array(lit(null))).as("collected"))
  .show()

+---------+
|collected|
+---------+
|[a, b, d]|
|   [f, m]|
|      [d]|
+---------+

推荐阅读