r - How to extract the steps of a tuned mlr3 graph one by one?
Problem description
My code is as follows:
library(mlr3verse)
library(mlr3pipelines)
library(mlr3filters)
library(paradox)
filter_importance = mlr_pipeops$get(
"filter",
filter = FilterImportance$new(learner = lrn("classif.ranger", importance = "impurity")),
param_vals = list(filter.frac = 0.7)
)
learner_classif = lrn(
"classif.ranger",
predict_type = "prob",
importance = "impurity",
num.trees = 500
)
polrn_classif = PipeOpLearner$new(learner_classif)
# create learner graph
glrn_classif = filter_importance %>>% polrn_classif
glrn_classif = GraphLearner$new(glrn_classif)
glrn_classif$predict_type = "prob"
# task
task = tsk("german_credit")
# set search_space
ps_classif = ParamSet$new(list(
ParamInt$new("classif.ranger.num.trees", lower = 300, upper = 500),
ParamDbl$new("classif.ranger.sample.fraction", lower = 0.7, upper = 0.8)
))
# auto tuning
at = AutoTuner$new(
learner = glrn_classif,
resampling = rsmp("cv", folds = 3),
measure = msr("classif.auc"),
search_space = ps_classif,
terminator = trm("evals", n_evals = 3),
tuner = tnr("random_search")
)
# sampling
rr = resample(task, at, rsmp("cv", folds = 2))
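After I get the rr object from resample and the trained learner at, how can I extract what these steps are doing?
For example:
- How can I manually rerun the steps once I have the result from the at object?
- Which samples (train_index, test_index) were used at each step?
- Which variables were selected by the filter_importance step? What is the score of each variable in this step?
Thank you very much!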
Solution
To be able to fiddle with the models after resampling, it is best to call resample with store_models = TRUE.
Using your example:
library(mlr3verse)
set.seed(1)
rr <- resample(task,
at,
rsmp("cv", folds = 2),
store_models = TRUE)
Once resampling has finished, you can access the internals of the resulting object as follows.
To get the row ids in each fold:
rr$resampling$instance
#output
row_id fold
1: 5 1
2: 8 1
3: 9 1
4: 12 1
5: 13 1
---
996: 989 2
997: 993 2
998: 994 2
999: 995 2
1000: 996 2
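With these row ids and the tuned auto tuners we can generate the predictions manually.
Create a list of test indices, one element per fold: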
rsample <- split(rr$resampling$instance$row_id,
rr$resampling$instance$fold)
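The question also asks about training indices. A minimal sketch, using a small made-up fold table in place of rr$resampling$instance (the six row ids below are hypothetical): in cross-validation the fold column marks the test rows, so the training rows of a fold are all remaining row ids.

```r
# Hypothetical stand-in for rr$resampling$instance: 6 rows assigned to 2 folds
instance <- data.frame(row_id = c(1, 4, 5, 2, 3, 6),
                       fold   = c(1, 1, 1, 2, 2, 2))

# Test row ids per fold, same idea as the split() call above
test_ids <- split(instance$row_id, instance$fold)

# Training row ids per fold: every row id that is not in that fold's test set
train_ids <- lapply(test_ids, function(ids) setdiff(instance$row_id, ids))

test_ids[["1"]]   # 1 4 5
train_ids[["1"]]  # 2 3 6
```

On the real object, rr$resampling$train_set(i) and rr$resampling$test_set(i) should return the same ids for fold i without the manual bookkeeping.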
Iterate over the folds and predict with the tuned auto tuner of each fold:
lapply(1:2, function(i){
x <- rsample[[i]] #get the test row ids
task_test <- task$clone() #clone the task so we don't change the original task
task_test$filter(x) #filter on the test row ids
preds <- rr$learners[[i]]$predict(task_test) #use the trained autotuner and above filtered task
preds
}) -> preds_manual
Check that these predictions match the output of resample:
all.equal(preds_manual,
rr$predictions())
#output
TRUE
To get information about the tuning:
zz <- rr$data$learners()$learner
lapply(zz, function(x) x$tuning_result)
#output
[[1]]
classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1: 342 0.7931022 <list[7]>
x_domain classif.auc
1: <list[2]> 0.7981283
[[2]]
classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1: 407 0.7964164 <list[7]>
x_domain classif.auc
1: <list[2]> 0.7706533
The slot zz[[1]]$learner$state$model$importance contains the information about the filter_importance step.
Specifically,
lapply(zz, function(x) x$learner$state$model$importance$scores)
#output
[[1]]
amount status age
27.491369 25.776145 22.021369
duration purpose credit_history
18.732521 16.251643 14.884843
employment_duration savings property
11.225678 10.796583 9.078619
personal_status_sex present_residence installment_rate
8.914802 7.875384 7.491573
job number_credits other_installment_plans
6.293323 5.662485 5.345666
housing telephone other_debtors
4.869471 3.742213 3.548856
people_liable foreign_worker
2.632163 1.054919
[[2]]
amount duration age
26.764389 22.139400 20.749865
status purpose employment_duration
20.524764 11.793789 10.962301
credit_history installment_rate savings
10.416572 9.597835 9.491894
property present_residence job
9.403157 7.877391 6.760945
personal_status_sex housing other_installment_plans
6.699065 5.811131 5.710761
telephone other_debtors number_credits
4.716322 4.318972 3.974793
people_liable foreign_worker
3.196563 0.846520
contains the importance score of every feature, sorted from highest to lowest, while
lapply(zz, function(x) x$learner$state$model$importance$outtasklayout)
#output
[[1]]
id type
1: age integer
2: amount integer
3: credit_history factor
4: duration integer
5: employment_duration factor
6: installment_rate ordered
7: job factor
8: number_credits ordered
9: personal_status_sex factor
10: present_residence ordered
11: property factor
12: purpose factor
13: savings factor
14: status factor
[[2]]
id type
1: age integer
2: amount integer
3: credit_history factor
4: duration integer
5: employment_duration factor
6: housing factor
7: installment_rate ordered
8: job factor
9: personal_status_sex factor
10: present_residence ordered
11: property factor
12: purpose factor
13: savings factor
14: status factor
contains the features that are kept after the filter step.
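The two outputs are consistent with each other: with filter.frac = 0.7 and 20 features, the 14 highest-scoring features are kept. A minimal sketch in base R that reproduces the fold 1 selection from the scores printed above. The rounding rule is an assumption about how the fraction is turned into a count; here 0.7 * 20 is 14, so the choice of rule does not affect the result.

```r
# Fold 1 importance scores, copied from the output above
scores <- c(amount = 27.491369, status = 25.776145, age = 22.021369,
            duration = 18.732521, purpose = 16.251643, credit_history = 14.884843,
            employment_duration = 11.225678, savings = 10.796583, property = 9.078619,
            personal_status_sex = 8.914802, present_residence = 7.875384,
            installment_rate = 7.491573, job = 6.293323, number_credits = 5.662485,
            other_installment_plans = 5.345666, housing = 4.869471,
            telephone = 3.742213, other_debtors = 3.548856,
            people_liable = 2.632163, foreign_worker = 1.054919)

frac <- 0.7                                # same value as filter.frac
n_keep <- round(frac * length(scores))     # assumed rounding rule; gives 14 here

# Names of the n_keep highest-scoring features
kept <- names(sort(scores, decreasing = TRUE))[seq_len(n_keep)]

# Alphabetical order (C-locale), to compare with the outtasklayout listing
kept_sorted <- sort(kept, method = "radix")
kept_sorted
```

The result lists the same 14 features as the first outtasklayout table, including number_credits and excluding housing.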