machine-learning - 从 mlr 包的 resample 函数中获取特定的随机森林变量重要性度量
问题描述
我正在使用mlr包的resample()
功能对随机森林模型进行 4000 次子采样(下面的代码片段)。
如您所见,要在resample()
我使用randomForest包中创建随机森林模型。
我想获得每个子样本迭代的随机森林模型的重要性结果(所有类的准确性平均下降)。作为重要性衡量标准,我现在可以得到的是基尼指数的平均下降。
我可以从 mlr 的源代码中看到,makeRLearner.classif.randomForest 中的 函数getFeatureImportanceLearner.classif.randomForest()
(第 69 行)使用函数(第 83 行)从类的结果对象中获取重要性值。但是从源代码(第 73 行)可以看出,它使用 2L 作为默认值。我希望它使用 1L(第 75 行)作为值(平均精度下降)。randomForest::importance()
randomForest
如何将 2L 的值传递给resample()
函数(下面代码中的“extract = getFeatureImportance”行),以便getFeatureImportanceLearner.classif.randomForest()
函数获取该值并设置ctrl$type = 2L
(第 73 行)?
rf_task <- makeClassifTask(id = 'task',
data = data[, -1], target = 'target_var',
positive = 'positive_var')
rf_learner <- makeLearner('classif.randomForest', id = 'random forest',
par.vals = list(ntree = 1000, importance = TRUE),
predict.type = 'prob')
base_subsample_instance <- makeResampleInstance(rf_boot_desc, rf_task)
rf_subsample_result <- resample(rf_learner, rf_task,
base_subsample_instance,
extract = getFeatureImportance,
measures = list(acc, auc, tpr, tnr,
ppv, npv, f1, brier))
我的解决方案:下载了 mlr 包的源代码。将源文件第 73 行更改为 1L ( https://github.com/mlr-org/mlr/blob/v2.15.0/R/RLearner_classif_randomForest.R )。从命令行安装包并使用它。不是最佳解决方案,而是解决方案。
解决方案
您提供了许多实际上与您的问题无关的细节,至少我是如何理解的。所以我写了一个包含答案的简单 MWE。这个想法是您必须编写一个简短的包装器,getFeatureImportance
以便您可以传递自己的参数。粉丝purrr
可以这样做,但在这里我是手动purrr::partial(getFeatureImportance, type = 2)
编写的。myExtractor
library(mlr)
rf_learner <- makeLearner('classif.randomForest', id = 'random forest',
par.vals = list(ntree = 100, importance = TRUE),
predict.type = 'prob')
measures = list(acc, auc, tpr, tnr,
ppv, npv, f1, brier)
myExtractor = function(.model, ...) {
getFeatureImportance(.model, type = 2, ...)
}
res = resample(rf_learner, sonar.task, cv10,
measures = measures, extract = myExtractor)
# first feature importance result:
res$extract[[1]]
# all values in a matrix:
sapply(res$extract, function(x) x$res)
如果你想做一个引导学习者,也许你也应该看看makeBaggingWrapper
而不是通过resample
.
推荐阅读
- html - CSS 媒体查询不能在移动设备上运行,但可以在桌面上运行
- linux - “错误:EACCES:权限被拒绝” - 在 AWS EC2 实例上运行 Meteor 应用程序 (OHIF)
- python - from django.db.models import Q, Count, F, JSONField ImportError: cannot import name 'JSONField'
- laravel - 在 Livewire 中单击之前如何防止使用参数
- c - c时钟()在不同的操作系统
- r - 如何删除包含某些文本的 tibble 行?
- apache-flink - Flink:使用 CSV 文件的事件时间聚合
- flutter - 单击单选按钮时如何更改容器
- python - 即使在执行程序后,Postgresql 数据库上的连接仍处于活动状态
- cypress - 赛普拉斯:如何通过检查 URL 有条件地跳过测试