python - 如何为流水线式多输出分类器绘制树?
问题描述
我想解释我的模型,了解为什么这个模型给我的标签是 1 或 0。,所以我想使用 xgboost 的 plot_tree 函数。我的问题是多标签分类问题;我写了以下代码;
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, shuffle=True, random_state=42)
model = MultiOutputClassifier(
xgb.XGBClassifier(objective="binary:logistic",
colsample_bytree = 0.5,
gamma = 0.1
))
#Define a pipeline
pipeline = Pipeline([("preprocessing", col_transformers), ("XGB", model)])
pipeline.fit(X_train, y_train)
predicted = pipeline.predict(X_test)
xgb.plot_tree(pipeline, num_trees=4)
这段代码给了我错误;
“管道”对象没有属性“get_dump”
如果我更改代码;
xgb.plot_tree(pipeline.named_steps["XGB"], num_trees=4)
“MultiOutputClassifier”对象没有属性“get_dump”
我怎么解决这个问题?
解决方案
您只能在或实例上使用该plot_tree
功能。当您传递一个对象时,您的第一种情况会失败,而在第二种情况下,您正在传递对象。Booster
XGBModel
Pipeline
MultiOutputClassifier
相反,您必须传递适合的XGBClassifier
对象。但是,请注意MultiOutputClassifier
实际工作方式:
该策略包括为每个目标拟合一个分类器。
这意味着您将有一个适合每个标签的模型。
您可以使用 的estimators_
属性访问它们MultiOutputClassifier
。例如,您可以像这样检索第一个标签的模型:
xgb.plot_tree(pipeline.named_steps["XGB"].estimators_[0], num_trees=4)
如果想要全部,则必须遍历该estimators_
属性返回的列表。
推荐阅读
- mongodb - 如何在字符串字段中使用日期条件删除 mongodb 中的文档?
- javascript - 如何一次调用ajax请求在多个组件实例中加载数据
- report - 如何创建表格的自定义“摘要/总计”
- google-chrome - 无法在 Mozilla 中使用 JMETER 记录网络流量在将端口更改为 8080 时出错
- google-chrome - Chrome SVG 渲染工件
- c++ - 泛化我的可变参数模板函数时出错
- python - 由 PyQt4 创建并在 python 中执行的 GUI 没有打开
- azure - 了解 Azure CDN
- macos - CAN 所需的原始套接字在 MacOS 下不起作用 - 套接字:协议不支持地址系列
- symfony - 奏鸣曲默认过滤器值