multilabel-classification - 如何在 mlr 中使用调整参数进行多标签分类?
问题描述
问题
我正在尝试使用 mlr 包在 r 中运行多标签分类。我使用https://www.rdocumentation.org/packages/mlr/versions/2.19.0/topics/makeMultilabelClassifierChainsWrapper来实现多标签分类。但我需要添加超参数调整。这似乎会产生各种问题。我按照https://mlr.mlr-org.com/articles/tutorial/tune.html上的示例调整参数。tuneParams 需要参数 resample 并且我被卡住了。
示例数据
age <- c(round(rnorm(120,mean = 50,sd = 10)))
sex <- c(round(rnorm(120,mean = 0.5,sd = 0.2)))
l1 <- as.logical(c(round(rnorm(120,mean = 0.5,sd = 0.2))))
l2 <- as.logical(c(round(rnorm(120,mean = 0.5,sd = 0.2))))
l3 <- as.logical(c(round(rnorm(120,mean = 0.5,sd = 0.2))))
l4 <- as.logical(c(round(rnorm(120,mean = 0.5,sd = 0.2))))
data <- as.data.frame(cbind(age,sex,l1,l2,l3,l4))
实际上,我有 12 个标签,但为了便于查看,我省略了其他标签。这个想法是 l1 直到 l4 是逻辑向量。不知何故,这不起作用,所以我希望你能解决这个问题。但请注意,这不是我的主要问题。
代码
task <- makeMultilabelTask(data = data, target = label_bact)
ps <- makeParamSet(
makeDiscreteParam("ntree",values = c(50,100,150,200,300,500,550)),
makeDiscreteParam("mtry",values = c(1,2,3,4,5))
)
ctrl <- makeTuneControlGrid()
rdesc <- makeResampleDesc(method = "CV",iters = 5, predict = "test",
stratify.cols = c(l1,l2,l3,l4)
measure <- acc
learner <- "classif.randomForest"
lrn <- makeLearner(learner)
lrn <- makeMultilabelClassifierChainsWrapper(lrn, order = NULL)
lrn <- setPredictType(lrn,"prob")
res <- tuneParams(lrn,task = task,resample = rdesc, par.set = ps,control = ctrl)
错误
我得到的错误:
Error in tuneParams(lrn, task = task, resample = rdesc, par.set = ps, :
Assertion on 'resample.fun' failed: Must be a function, not 'CVDesc/ResampleDesc'.
所以我添加了代码行:
r <- resample(learner = lrn,task = task,rdesc)
这告诉我
Error in makeResampleInstance(resampling, task = task) :
Stratification for tasks of type 'multilabel' not supported
查看
这通过以下方式得到证实:
>rdesc
Resample description: cross-validation with 5 iterations.
Predict: test
Stratification: FALSE
问题
- 所以第一个问题是如何解决多个结果标签的分层(在 makeResampleDesc 函数中)?
- 第二个问题是如何使 tuneParams 功能起作用?
- 相关问题是有没有办法跳过重采样参数,因为我已经在这些函数之外进行了 CV 和分层?
提前致谢!
解决方案
尝试这个:
rf <- makeLearner("classif.randomForest", predict.type = "response", par.vals = list(ntree = 200, mtry = 3))
rf$par.vals <- list(importance = TRUE)
rf=makeDownsampleWrapper(rf, dw.stratify = TRUE)
rf=makeMultilabelClassifierChainsWrapper( rf )
rf_param <- makeParamSet(
makeIntegerParam("ntree",lower = 50, upper = 500),
makeIntegerParam("mtry", lower = 3, upper = 10),
makeIntegerParam("nodesize", lower = 10, upper = 50)
)
rancontrol <- makeTuneControlRandom(maxit = 5L)
#set 3 fold cross validation
set_cv <- makeResampleDesc("CV",iters = 3L)
#hypertuning
rf_tune <- tuneParams(learner = rf, resampling = set_cv, task =base_pred.task, par.set = rf_param, control = rancontrol, measures = multilabel.hamloss)
这个对我有用!我希望它对你有用
推荐阅读
- jquery - 使用 Ajax Jquery 时 DataTable 不刷新
- fable-r - 在 tidyverts 包中按键创建时间序列交叉验证切片
- java - 以下 switch 语句有什么问题?
- meshlab - 默认情况下,Meshlab 如何在导出网格中计算法线?
- javascript - 没有 Node.js 我怎么能做 tweetnacl.sealedbox.seal?
- node.js - waitForSelector 和 querySelectorAll 与 puppeteer
- windows - 如何加速这个 Bash mv 脚本?
- mysql - MariaDB 无法更改/禁用 sql 模式
- php - 在 AcceptSuite\create-an-accept-payment-transaction.php 中找不到类 'net\authorize\api\contract\v1\MerchantAuthenticationType'
- javascript - React Native 中的可编辑文本