首页 > 解决方案 > 添加到管道:多标签分类器预测的修改器

问题描述

我创建了一个管道,它在最后使用随机森林分类器进行多标签预测。

现在有时所有类的预测都是 0。在这种情况下,我想修改预测,使第一个标签默认为 1。

例如,

在 sklearn 管道中实现这一目标的方法是什么。是否有某种我可以扩展和实现然后添加到管道中的 Prediction-Modifier 类?

所以我想做的就是在完成后修改随机森林分类器的预测。我可以在代码中轻松地做到这一点,但我不知道如何在管道中做到这一点,例如我可以在网格搜索中做到这一点。

标签: scikit-learn

解决方案


我采取了如下扩展随机森林分类器的方法:

class ModifiedRandomForestClassifier(RandomForestClassifier):
  def predict(self, X):
    y_pred = super().predict(X)
    # default to first class if multi-label prediction is (0,0,...,0)
    y_pred[y_pred.sum(axis=1) == 0, 0] = 1
    return y_pred

推荐阅读