python - 如何对没有 predict_proba 或 decision_function 的模型使用 CalibratedClassifierCV
问题描述
我正在尝试使用CalibratedClassifierCV()
以创建更好的拟合校准曲线来校准我的模型输出。据我了解,对于基于树的模型、神经网络,必须使用这种方法校准输出以获得最佳性能。但是,当我尝试这样做时,它会引发错误。
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import RandomizedSearchCV
pipe_dtr = Pipeline(steps=[('preprocessor', preprocessor),
('clf', DecisionTreeRegressor(random_state=62))])
params_dtr = {
'clf__max_depth' : np.arange(1,100,5),
'clf__min_samples_leaf' : [0.01, 0.1, 1]
}
gs_dtr = RandomizedSearchCV(estimator=pipe_dtr,
param_distributions=params_dtr,
n_iter=25,
scoring='roc_auc',
cv=5)
gs_dtr.fit(X_train, y_train)
calib_pipe_dtr = Pipeline(steps=[('preprocessor', preprocessor),
('calibrator', CalibratedClassifierCV(gs_dtr.best_estimator_, cv='prefit'))])
calib_pipe_dtr.fit(X_train,y_train)
这引发了以下错误
RuntimeError:分类器没有 decision_function 或 predict_proba 方法。
我该如何解决这个问题..请发表意见。谢谢
解决方案
回归模型应该用于 CalibratedClassifierCV。如果您正在解决分类问题,请使用 DecisionTreeClassifier。
工作示例:
from sklearn.datasets import load_iris
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import train_test_split
X, y= load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.2, stratify=y)
pipe_dtr = Pipeline(steps=[('preprocessor', StandardScaler()),
('clf', DecisionTreeClassifier(random_state=62))])
params_dtr = {
'clf__max_depth' : np.arange(1,100,5),
'clf__min_samples_leaf' : [0.01, 0.1, 1]
}
gs_dtr = RandomizedSearchCV(estimator=pipe_dtr,
param_distributions=params_dtr,
n_iter=25,
scoring='accuracy',
cv=5)
gs_dtr.fit(X_train, y_train)
calib_pipe_dtr = Pipeline(steps=[('preprocessor', StandardScaler()),
('calibrator', CalibratedClassifierCV(gs_dtr.best_estimator_, cv='prefit'))])
calib_pipe_dtr.fit(X_train, y_train)
推荐阅读
- google-apps-script - 使用 CardService 的多选下拉菜单
- perl - Perl:如果找到匹配项,如何插入一行?
- .net - 使用 .NET 控制台应用程序监视剪贴板
- owl - protege 颗粒推断问题
- javascript - javascript中的数组数组到对象数组的数组
- ios - objc 方法在我的 JS 端没有被识别
- linux - AWS EC2 packet_write_wait:连接到 UNKNOWN 端口 65535:管道损坏
- ruby-on-rails - 添加新测试时,以前工作测试的权限被拒绝
- javascript - 问:在自执行功能中,谁被分配了数字 10
- windows - 浏览到我的自定义本地 HTTPS 服务器时如何避免无效证书警告?