java - IllegalArgumentException:要求失败:rawPredictionCol 向量的长度必须为 2,但在 Apache Spark 中测试模型时得到 3
问题描述
我正在尝试使用 OneR 算法在Apache Spark 3.1.1中创建模型和评估。我有.csv
标准化数据的文件(所有值都是double
,但有些值非常接近 0)。
我正在阅读MLlib 主要指南OnevsRest,代码与此非常相似:
SparkSession session = SparkSession
.builder()
.appName("Spark test")
.master("local")
.getOrCreate();
JavaRDD<LabeledPoint> data = loadData(session, "path.csv");
LogisticRegression logisticRegression = new LogisticRegression().setMaxIter(20);
OneVsRest oneR = new OneVsRest().setClassifier(logisticRegression);
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
MulticlassMetrics metrics = new MulticlassMetrics(data.rdd());
MulticlassClassificationEvaluator multiEvaluator = new MulticlassClassificationEvaluator()
.setMetricName("accuracy");
JavaRDD<LabeledPoint>[] javaRDDS = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingRDD = javaRDDS[0], testRDD = javaRDDS[1];
Dataset<Row> trainingDataset = session.createDataFrame(trainingRDD, LabeledPoint.class);
Dataset<Row> testDataset = session.createDataFrame(testRDD, LabeledPoint.class);
OneVsRestModel oneRModel = oneR.fit(trainingDataset);
Dataset<Row> oneRPredictions = oneRModel.transform(testDataset).select("prediction", "label");
double oneRAcc = evaluator.evaluate(oneRPredictions);
System.out.println("OneR: \r\n");
System.out.println("Accuracy: " + oneRAcc);
System.out.println("--------------------------------");
session.close();
此代码引发异常:
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: rawPredictionCol vectors must have length=2, but got 3
at scala.Predef$.require(Predef.scala:281)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.$anonfun$getMetrics$1(BinaryClassificationEvaluator.scala:126)
at scala.runtime.java8.JFunction1$mcVI$sp.apply(JFunction1$mcVI$sp.java:23)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.getMetrics(BinaryClassificationEvaluator.scala:126)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:100)
at Classification.main(Classification.java:64)
为什么这段代码不起作用?我认为问题出在.select("prediction", "label")
因为我不知道我在Dataset<Row>
之后有什么transform
,但vectors must have length of 2 not 3
很奇怪。我正在尝试使用 3 个类进行多类分类。
编辑
我正在使用BinaryClassificationEvaluator evaluator
而不是MulticlassClassificationEvaluator multiEvaluator
错误地使用。现在错误消息是有道理的。
解决方案
推荐阅读
- decompiling - 任何人都可以识别此代码片段的编程语言吗?
- r - 使用前 50 行计算 R 中的运行平均值
- android-studio - 设置为不将 gmail 发送到特定的 gmail
- amazon-web-services - HashiCorp Vault 创建用户以登录 AWS 管理控制台
- javascript - 从 API 中提取数据
- java - Spring-Boot WebFlux getFormData fails to parse real x-www-form-urlencoded data?
- node.js - 使用 socks5 时的 Request-Promise 库错误:- RequestError:错误:getaddrinfo ENOTFOUND
- c++ - C2011 'class' 类型重新定义
- events - wxpython 4.01 / Python 3.7 中事件函数的值
- python - 当 \d 什么都不返回时,如何告诉 re 返回所有数字?