IllegalArgumentException: requirement failed: rawPredictionCol vectors must have length=2, but got 3 when testing a model in Apache Spark

Problem Description

I am trying to build and evaluate a model with the OneR (OneVsRest) algorithm in Apache Spark 3.1.1. I have a .csv file of standardized data (all values are doubles, but some are very close to 0).

I was following the OneVsRest example in the MLlib main guide, and my code is very similar to it:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

SparkSession session = SparkSession
                .builder()
                .appName("Spark test")
                .master("local")
                .getOrCreate();

JavaRDD<LabeledPoint> data = loadData(session, "path.csv");

// One-vs-rest with logistic regression as the base binary classifier
LogisticRegression logisticRegression = new LogisticRegression().setMaxIter(20);
OneVsRest oneR = new OneVsRest().setClassifier(logisticRegression);

BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
MulticlassMetrics metrics = new MulticlassMetrics(data.rdd());
MulticlassClassificationEvaluator multiEvaluator = new MulticlassClassificationEvaluator()
        .setMetricName("accuracy");

// 70/30 train/test split
JavaRDD<LabeledPoint>[] javaRDDS = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingRDD = javaRDDS[0], testRDD = javaRDDS[1];
Dataset<Row> trainingDataset = session.createDataFrame(trainingRDD, LabeledPoint.class);
Dataset<Row> testDataset = session.createDataFrame(testRDD, LabeledPoint.class);

OneVsRestModel oneRModel = oneR.fit(trainingDataset);
Dataset<Row> oneRPredictions = oneRModel.transform(testDataset).select("prediction", "label");

// this evaluate() call is the one that throws the exception below
double oneRAcc = evaluator.evaluate(oneRPredictions);
System.out.println("OneR: \r\n");
System.out.println("Accuracy: " + oneRAcc);
System.out.println("--------------------------------");

session.close();

This code throws the following exception:

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: rawPredictionCol vectors must have length=2, but got 3
    at scala.Predef$.require(Predef.scala:281)
    at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.$anonfun$getMetrics$1(BinaryClassificationEvaluator.scala:126)
    at scala.runtime.java8.JFunction1$mcVI$sp.apply(JFunction1$mcVI$sp.java:23)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.getMetrics(BinaryClassificationEvaluator.scala:126)
    at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:100)
    at Classification.main(Classification.java:64)

Why doesn't this code work? I suspect the problem is in .select("prediction", "label"), since I don't know exactly which columns the Dataset<Row> contains after transform, but the message "vectors must have length=2, but got 3" is strange: I am deliberately doing multiclass classification with 3 classes.
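A quick way to see which columns transform actually produces is to print the schema before calling select. A small sketch, reusing the oneRModel and testDataset variables from the code above (note that the exception itself already shows a rawPrediction column exists and holds length-3 vectors):

Dataset<Row> transformed = oneRModel.transform(testDataset);
// prints the column names and types added by the model,
// e.g. features, label, rawPrediction, prediction
transformed.printSchema();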

EDIT

I was mistakenly using the BinaryClassificationEvaluator evaluator instead of the MulticlassClassificationEvaluator multiEvaluator. Now the error message makes sense.

Tags: java, apache-spark, apache-spark-ml, apache-spark-dataset

Solution
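
The failing check lives in BinaryClassificationEvaluator: it reads the rawPredictionCol (by default "rawPrediction") and, as the message says, requires every vector in that column to have length 2, one score per class of a binary problem. A OneVsRest model trained on 3 classes emits length-3 rawPrediction vectors, so a binary evaluator cannot score it. For a 3-class problem, evaluate with MulticlassClassificationEvaluator instead, which only compares the prediction and label columns. A minimal sketch of the corrected evaluation, reusing the multiEvaluator already defined in the question:

Dataset<Row> oneRPredictions = oneRModel
        .transform(testDataset)
        .select("prediction", "label");

// MulticlassClassificationEvaluator (configured with metricName "accuracy"
// above) compares predictionCol against labelCol; no rawPrediction
// vector is needed, so the length check never runs.
double oneRAcc = multiEvaluator.evaluate(oneRPredictions);
System.out.println("Accuracy: " + oneRAcc);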

