How to transform rows within a partition

Problem description

I have a Spark scenario in which a DataFrame has to be partitioned, and the result should then be processed one partition at a time.

List<String> data = Arrays.asList("con_dist_1", "con_dist_2", 
        "con_dist_3", "con_dist_4", "con_dist_5",
        "con_dist_6");
Dataset<Row> codes = sparkSession.createDataset(data, Encoders.STRING());
Dataset<Row> partitioned_codes = codes.repartition(col("codes"));

// I need to partition it due to a functional requirement
partitioned_codes.foreachPartition(itr -> {
    if (itr.hasNext()) {
        Row inrow = itr.next();
        System.out.println("inrow.length : " + inrow.length());
        System.out.println(inrow.toString());
        List<Object> objs = inrow.getList(0);
    }
});

I get this error:

Caused by: java.lang.ClassCastException: java.lang.String cannot be cast to scala.collection.Seq
    at org.apache.spark.sql.Row$class.getSeq(Row.scala:283)
    at org.apache.spark.sql.catalyst.expressions.GenericRow.getSeq(rows.scala:166)
    at org.apache.spark.sql.Row$class.getList(Row.scala:291)
    at org.apache.spark.sql.catalyst.expressions.GenericRow.getList(rows.scala:166)

Question: how should I handle foreachPartition here? In each invocation, itr contains a group of rows. How do I get those rows out of itr?

Test 1:

inrow.length: 0
[]
inrow.length: 0
[]
2020-03-02 05:22:14,179 [Executor task launch worker for task 615] ERROR org.apache.spark.executor.Executor - Exception in task 110.0 in stage 21.0 (TID 615)
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.String
    at org.apache.spark.sql.Row$class.getString(Row.scala:255)
    at org.apache.spark.sql.catalyst.expressions.GenericRow.getString(rows.scala:166)

Output 1:

inrow.length: 0
[]
inrow.length: 0
[]
inrow.length: 1
[con_dist_1]
inrow.length: 1
[con_dist_2]
inrow.length: 1
[con_dist_5]
inrow.length: 1
[con_dist_6]
inrow.length: 1
[con_dist_4]
inrow.length: 1
[con_dist_3]

Tags: java, scala, apache-spark, java-8, apache-spark-sql

Solution


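All the rows of a partition are in itr, so when you call itr.next() once you only get the first row. If you need to print all the rows, you can use a while loop, or you can convert the iterator to a list like this (which I suspect is what you wanted):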

partitioned_codes.foreachPartition(itr -> {
    Iterable<Row> rowIt = () -> itr;
    List<String> objs = StreamSupport.stream(rowIt.spliterator(), false)
            .map(row -> row.getString(0))
            .collect(Collectors.toList());

    System.out.println("inrow.length: " + objs.size());
    System.out.println(objs);
});
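
The while-loop alternative mentioned above is not shown in the answer, so here is a minimal sketch of it in plain Java. The class PartitionRowsSketch and its drain helper are hypothetical names made up for illustration and are not part of Spark; a plain list iterator stands in for the Iterator<Row> that foreachPartition passes in.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical helper for illustration only; not part of the Spark API.
class PartitionRowsSketch {

    // Collects every remaining element of the iterator into a list.
    static <T> List<T> drain(Iterator<T> itr) {
        List<T> rows = new ArrayList<>();
        // `while` instead of `if`: keep reading until the iterator is exhausted.
        while (itr.hasNext()) {
            rows.add(itr.next());
        }
        return rows;
    }

    public static void main(String[] args) {
        // Stand-in for the rows of one partition.
        Iterator<String> partition =
                Arrays.asList("con_dist_1", "con_dist_2").iterator();
        List<String> rows = drain(partition);
        System.out.println("inrow.length: " + rows.size());
        System.out.println(rows);
    }
}
```

Inside foreachPartition the same loop would run over itr directly, reading each value with row.getString(0) as in the snippet above. An empty partition simply yields an empty list, which matches the "inrow.length: 0" lines in the question's output.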

The sample code you posted did not compile for me, so here is the version I tested:

List<String> data = Arrays.asList("con_dist_1", "con_dist_2", 
        "con_dist_3", "con_dist_4", "con_dist_5",
        "con_dist_6");
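// Schema with a single nullable string column named "codes".
StructType struct = new StructType()
        .add(DataTypes.createStructField("codes", DataTypes.StringType, true));
// `sc` is not defined in this snippet; it is assumed to be a JavaSparkContext.
Dataset<Row> codes = sparkSession.createDataFrame(sc.parallelize(data, 2)
                        .map(s -> RowFactory.create(s)), struct);
// Hash-partition the rows by the "codes" column.
Dataset<Row> partitioned_codes = codes.repartition(org.apache.spark.sql.functions.col("codes"));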

partitioned_codes.foreachPartition(itr -> {
    Iterable<Row> rowIt = () -> itr;
    List<String> objs = StreamSupport.stream(rowIt.spliterator(), false)
            .map(row -> row.getString(0))
            .collect(Collectors.toList());

    System.out.println("inrow.length: " + objs.size());
    System.out.println(objs);
});
