python - 预测类别及其相应的概率
问题描述
我已经使用 maxvoting(决策树、随机森林、逻辑回归)分类器构建了一个机器学习模型。我的输入为
{“工资”:50000,“当前贷款”:15000,“信用评分”:616,“申请贷款”:25000 }
当我将此数据传递给我的模型时。它给出的预测为
{“状态”:批准}
但我需要像这样检索响应
{“状态”:批准,“准确性”:0.87}
任何帮助将非常感激
解决方案
看起来您可能正在使用 sklearn 的VotingClassifier
. 安装好分类器后,您可以通过属性查看与每个类关联的概率predict_proba
。请注意,这不是准确度,而是每个类别的相关概率。因此,如果您希望测试样本属于 class 的概率,则必须在相应列上n
索引输出。y_pred_prob
这是使用 sklearn 的 iris 数据集的示例:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
clf1 = LogisticRegression(multi_class='multinomial', random_state=1)
clf2 = RandomForestClassifier(n_estimators=50, random_state=1)
clf3 = GaussianNB()
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
eclf2 = VotingClassifier(estimators=[
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
voting='soft')
eclf2 = eclf2.fit(X_train, y_train)
我们可以得到与第一类相关的概率,例如:
eclf2.predict_proba(X_test)[:,0].round(2)
array([0.99, 0. , 0. , 0. , 0. , 0. , 0.01, 0.01, 0. , 0. , 0. ,
0.99, 0. , 0.99, 0.99, 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0.01, 0.98, 0. , 1. , 0.99, 0. , 0. , 0. , 0.99, 0.98,
0. , 0.99, 0. , 0.01, 0.99])
最后,要获得您所描述的输出,您可以使用 , 返回的结果predict
来索引 2D 概率数组,如下所示:
import pandas as pd
y_pred = eclf2.predict(X_test)
y_pred_prob = eclf2.predict_proba(X_test).round(2)
associated_prob = y_pred_prob[np.arange(len(y_test)), y_pred]
pd.DataFrame({'class':y_pred, 'Accuracy':associated_prob})
class Accuracy
0 0 0.99
1 2 0.84
2 2 1.00
3 1 0.95
4 2 0.99
5 2 0.91
6 1 0.98
7 1 0.98
8 1 0.93
或者,如果您更喜欢将输出作为字典:
pd.DataFrame({'class':y_pred, 'Accuracy':associated_prob}).to_dict(orient='index')
{0: {'class': 0, 'Accuracy': 0.99},
1: {'class': 2, 'Accuracy': 0.84},
2: {'class': 2, 'Accuracy': 1.0},
3: {'class': 1, 'Accuracy': 0.95},
4: {'class': 2, 'Accuracy': 0.99},
推荐阅读
- android - 注销后如何使用 Retrofit 处理“204 no content”响应?
- c++ - 将unique_ptr / auto_ptr隐式转换为基本类型不起作用?
- azure - Azure Synapse Analytics 参数化笔记本不起作用
- angular - 被 CORS 阻止:“Access-Control-Allow-Origin”标头包含多个值“*、*”、
- laravel - 在 laravel 集体中提交表单时显示加载器
- fitnesse - FitNesse 使用全局变量加载包含的页面
- syntax - PHP 解析错误:语法错误,第 7 行的意外“定义”(T_STRING)wp-config.php
- python - 如何标准化 matplotlib 直方图中的概率分布值?
- mysql - MySQL View 在单列中显示多列中的行
- c++ - vscode 上的 C++:调试得到错误,但运行没有调试的代码仍在运行