首页 > 解决方案 > Spark MLlib 仅在阈值大于值时进行预测

问题描述

我有一个多类分类(38 类)问题,并在 Spark ML 中实现了一个管道来解决它。这就是我生成模型的方式。

val nb = new NaiveBayes()
  .setLabelCol("id")
  .setFeaturesCol("features")
  .setThresholds(Seq(1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25,1.25).toArray)

val pipeline = new Pipeline()
  .setStages(Array(stages, assembler, nb))

val model = pipeline.fit(trainingSet)

我希望我的模型只有在其置信度(概率)大于 0.8% 时才能预测一个类。

我在 Spark 文档中进行了很多搜索,以更好地理解阈值参数的含义,但我发现的唯一相关信息是这个:

多类分类中的阈值来调整预测每个类的概率。数组的长度必须等于类的数量,除了最多一个值可能是 0 之外,值 > 0。预测具有最大值 p/t 的类,其中 p 是该类的原始概率,t 是该类的临界点。

这就是为什么我的阈值是 1.25。

问题是,无论我为阈值插入的值如何,它都不会影响我的预测。

你知道是否有可能只预测置信度(概率)大于特定阈值的类?

另一种方法是仅选择概率大于该阈值的预测,但我希望这可以使用框架来实现。

谢谢。

标签: apache-sparkmachine-learningapache-spark-mllib

解决方案


推荐阅读