python - 在多标签分类中将参数传递给较低级别的 XGBoost 估计器
问题描述
我有一个多标签分类问题,我想为每个标签(总共 4 个)训练一个 XGBoost 模型;然后我结合了四个 XGBoost 估计器,这要归功于sklearn.multioutput.MultiOutputClassifier
(链接)。
另外,我想对 XGBoost 的超参数进行随机搜索,这要归功于RandomizedSearchCV
.
下面有一些可重现的代码来解释我的意图。
import xgboost as xgb
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.multioutput import MultiOutputClassifier
from sklearn.datasets import make_multilabel_classification
# create dataset
X, y = make_multilabel_classification(n_samples=3000, n_features=50, n_classes=4, n_labels=1,
allow_unlabeled=False, random_state=42)
# Split dataset into training and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)
# hyper-parameters space for the random search
random_grid = {
'n_estimators': [200, 300, 400],
'learning_rate': [0.05, 0.1, 0.2],
'max_depth': [3, 4, 5],
'min_child_weight': [1, 3]
}
xgb_estimator = xgb.XGBClassifier(objective='binary:logistic')
xgb_model = MultiOutputClassifier(xgb_estimator)
# random search instance
xgb_random_search = RandomizedSearchCV(
estimator=xgb_model, param_distributions=random_grid,
scoring=['accuracy'], refit='accuracy', n_iter=2, cv=3, verbose=True, random_state=1234, n_jobs=2
)
# fit the random search
xgb_random_search.fit(X_train, y_train)
但是,此代码给出以下(汇总)错误:
ValueError: Invalid parameter n_estimators for estimator MultiOutputClassifier.
Check the list of available parameters with `estimator.get_params().keys()`
事实上,在运行错误消息建议的代码行之后,我意识到我将超参数传递random_grid
给被MultiOutputClassifier
调用xgb_model
而不是 XGBoost 被调用xgb_estimator
,这是“较低级别”的估计器(因为它是“包含的”内xgb_model
)。
问题是:如何将超参数传递random_grid
给“较低级别”的 XGBoost 估计器?我觉得通过一些**kwargs
操作是可能的,但是经过一些试验我没有找到使用它们的方法。
解决方案
如果你运行xgb_model.get_params()
,你会发现参数的名称都是以estimator__
(双下划线)开头的。所以你的参数空间应该看起来像
random_grid = {
'estimator__n_estimators': [200, 300, 400],
'estimator__learning_rate': [0.05, 0.1, 0.2],
'estimator__max_depth': [3, 4, 5],
'estimator__min_child_weight': [1, 3]
}
这与其他 sklearn 嵌套模型(如Pipeline
和)一致ColumnTransformer
。
推荐阅读
- php - 从 PHP 下载 excel 文件
- angular - Angular 从 Observable 获取特定值
- javascript - 引导表单:如何正确处理分配给“更改”验证的输入的“无效”类?轮廓颜色与“无效反馈”不匹配
- maven - 我可以使用 Maven 生成一个可以使用不同的 testng xml 运行不同的 Selenium TestNG 测试的 Jar
- javascript - 如何根据变量显示不同的图像?
- java - ClassNotFoundException:项目编译后抛出/org/jdom2/JDOMException
- r - 将具有不同字符大小的列表转换为数据框
- python - 如何修复 Web2Py ”
(外键约束失败)”? - jquery - 使用函数onselect时JQuery Air DatePicker为空
- django - 在生产中执行迁移,但它们不是在我的服务器中创建的