首页 > 解决方案 > 预测类别及其相应的概率

问题描述

我已经使用 maxvoting(决策树、随机森林、逻辑回归)分类器构建了一个机器学习模型。我的输入为

{“工资”:50000,“当前贷款”:15000,“信用评分”:616,“申请贷款”:25000 }

当我将此数据传递给我的模型时。它给出的预测为

{“状态”:批准}

但我需要像这样检索响应

{“状态”:批准,“准确性”:0.87}

任何帮助将非常感激

标签: pythonmachine-learningscikit-learndecision-treesklearn-pandas

解决方案


看起来您可能正在使用 sklearn 的VotingClassifier. 安装好分类器后,您可以通过属性查看与每个类关联的概率predict_proba。请注意,这不是准确度,而是每个类别的相关概率。因此,如果您希望测试样本属于 class 的概率,则必须在相应列上n索引输出。y_pred_prob这是使用 sklearn 的 iris 数据集的示例:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, VotingClassifier

from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

clf1 = LogisticRegression(multi_class='multinomial', random_state=1)
clf2 = RandomForestClassifier(n_estimators=50, random_state=1)
clf3 = GaussianNB()

X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y)

eclf2 = VotingClassifier(estimators=[
        ('lr', clf1), ('rf', clf2), ('gnb', clf3)],
        voting='soft')

eclf2 = eclf2.fit(X_train, y_train)

我们可以得到与第一类相关的概率,例如:

eclf2.predict_proba(X_test)[:,0].round(2)

array([0.99, 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.01, 0.  , 0.  , 0.  ,
       0.99, 0.  , 0.99, 0.99, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.01, 0.98, 0.  , 1.  , 0.99, 0.  , 0.  , 0.  , 0.99, 0.98,
       0.  , 0.99, 0.  , 0.01, 0.99])

最后,要获得您所描述的输出,您可以使用 , 返回的结果predict来索引 2D 概率数组,如下所示:

import pandas as pd

y_pred = eclf2.predict(X_test)
y_pred_prob = eclf2.predict_proba(X_test).round(2)
associated_prob = y_pred_prob[np.arange(len(y_test)), y_pred]
pd.DataFrame({'class':y_pred, 'Accuracy':associated_prob})

    class  Accuracy
0       0      0.99
1       2      0.84
2       2      1.00
3       1      0.95
4       2      0.99
5       2      0.91
6       1      0.98
7       1      0.98
8       1      0.93

或者,如果您更喜欢将输出作为字典:

pd.DataFrame({'class':y_pred, 'Accuracy':associated_prob}).to_dict(orient='index')

 {0: {'class': 0, 'Accuracy': 0.99},
  1: {'class': 2, 'Accuracy': 0.84},
  2: {'class': 2, 'Accuracy': 1.0},
  3: {'class': 1, 'Accuracy': 0.95},
  4: {'class': 2, 'Accuracy': 0.99},

推荐阅读