首页 > 解决方案 > GBTClassifier 如何处理二进制分类的不平衡数据?

问题描述

我想用 GBTClassifier 对不平衡的数据集执行二进制分类。我没有从spark 文档中看到任何允许这样做的选项。

有人知道如何通过指定我们的数据不平衡这一事实来使用 GBTClassifier 吗?

谢谢

注意:我使用的是火花 2.3.2

标签: apache-sparkapache-spark-ml

解决方案


这是我天真的解决方案:随机下采样多数类。此解决方案的缺点是信息丢失,并且不适用于小型数据集。

val resampledTrainDF = {

    val positiveLabel = "1"
    val trainDF_positives = trainDF.where(F.col(label) === positiveLabel)
    val trainDF_negatives = trainDF.where(F.col(label) =!= positiveLabel)

    val withReplacement = trainDF_positives.count >= trainDF_negatives.count

    if (withReplacement) {
        // downsampling positives
        val sampSize = math.round(  (1.0 * trainDF_negatives.count / trainDF_positives.count) * 1000) / 1000.0
        println("Downsampling Positives by " + (1 - sampSize)*100 + " %")
        trainDF_positives.sample(false, sampSize).union(trainDF_negatives)
    } else { 
        //downsampling negatives
        val sampSize = math.round(  (1.0 * trainDF_positives.count / trainDF_negatives.count) * 1000) / 1000.0
        println("Downsampling Negatives by " + (1 - sampSize)*100 +  "%")
        trainDF_negatives.sample(false, sampSize).union(trainDF_positives)
    }

}

推荐阅读