首页 > 解决方案 > 如何仅获得概率大于 x 的预测

问题描述

我使用随机森林将文本分类到某些类别。当我使用我的测试数据时,我得到了 0.98 的准确度。但是使用另一组数据,整体准确度下降到 0.7。我认为,大多数行仍然具有很高的准确性。

所以现在我只想显示具有高置信度的预测类别。random-forrest 给了我一列“概率”,它是一个概率数组。如何获得所选预测的实际概率?

val randomForrest = new RandomForestClassifier()
      .setLabelCol(labelIndexer.getOutputCol)
      .setFeaturesCol(vectorAssembler.getOutputCol)
      .setProbabilityCol("probability")
      .setSeed(123)
      .setPredictionCol("prediction")

标签: random-forestapache-spark-mllib

解决方案


我最终想出了以下 udf 以获得最佳预测及其概率。如果有更方便的方法,请评论。

def getBestPrediction = udf((
  rawPrediction: org.apache.spark.ml.linalg.Vector, probability: org.apache.spark.ml.linalg.Vector) => {
  val bestPrediction = probability.argmax
  val bestProbability = probability(bestPrediction)     
  (bestPrediction, bestProbability)
})

推荐阅读