首页 > 解决方案 > 通过一列中的唯一值随机拆分 DataFrame

问题描述

我有一个 pyspark DataFrame,如下所示:

+--------+--------+-----------+
| col1   |  col2  |  groupId  |
+--------+--------+-----------+
| val11  | val21  |   0       |
| val12  | val22  |   1       |
| val13  | val23  |   2       |
| val14  | val24  |   0       |
| val15  | val25  |   1       |
| val16  | val26  |   1       |
+--------+--------+-----------+    

每行有一个groupId,多行可以有相同的groupId

我想将这些数据随机分成两个数据集。但是所有具有特定属性的数据都groupId必须在其中一个拆分中。

这意味着 if d1.groupId = d2.groupId、 thend1d2都在同一个拆分中。

例如:

# Split 1:

+--------+--------+-----------+
| col1   |  col2  |  groupId  |
+--------+--------+-----------+
| val11  | val21  |   0       |
| val13  | val23  |   2       |
| val14  | val24  |   0       |
+--------+--------+-----------+

# Split 2:
+--------+--------+-----------+
| col1   |  col2  |  groupId  |
+--------+--------+-----------+
| val12  | val22  |   1       |
| val15  | val25  |   1       |
| val16  | val26  |   1       |
+--------+--------+-----------+

在 PySpark 上做这件事的好方法是什么?我可以randomSplit以某种方式使用该方法吗?

标签: apache-sparkpysparkapache-spark-sql

解决方案


您可以使用randomSplit仅拆分不同groupId的 s,然后使用结果拆分源 DataFrame 使用join.

例如:

split1, split2 = df.select("groupId").distinct().randomSplit(weights=[0.5, 0.5], seed=0)
split1.show()
#+-------+
#|groupId|
#+-------+
#|      1|
#+-------+

split2.show()
#+-------+
#|groupId|
#+-------+
#|      0|
#|      2|
#+-------+

现在将它们加入到原始 DataFrame 中:

df1 = df.join(split1, on="groupId", how="inner")
df2 = df.join(split2, on="groupId", how="inner")

df1.show()
3+-------+-----+-----+
#|groupId| col1| col2|
#+-------+-----+-----+
#|      1|val12|val22|
#|      1|val15|val25|
#|      1|val16|val26|
#+-------+-----+-----+

df2.show()
#+-------+-----+-----+
#|groupId| col1| col2|
#+-------+-----+-----+
#|      0|val11|val21|
#|      0|val14|val24|
#|      2|val13|val23|
#+-------+-----+-----+

推荐阅读