Using a recipe in caretList

Question

I have been trying to train a list of models with caretList using a recipe.

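library(data.table)
library(caret)          # trainControl(), multiClassSummary()
library(caretEnsemble)
library(recipes)        # the package is named "recipes", not "recipe"
library(doParallel)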


data(iris)
dat <- setDT(iris)[Species %in% c("setosa", "versicolor"), ]
# my real dataset has a combination of factor and numeric columns
dat[, `:=`(
  factor_test = factor(ifelse(Sepal.Length > 4.8, 1, 0)),
  Species = factor(Species, levels = c("setosa", "versicolor"))  # adjust factor levels for species
)]

blueprint <- recipe(Species ~ ., data = dat) %>%
  step_nzv(all_predictors()) %>%
  step_pca(matches("Sepal.Width|Petal.Length"), prefix = "WidthLength_", threshold = .95)

# set seeds for tuning caret models
set.seed(42)
seed.list <- list()
for (i in 1:100) {
  seed.list[[i]] <- sample.int(n = 100000000, size = 50000)
}
seed.list[[101]] <- sample.int(n = 100000000, size = 1)

# LGOCV in lieu of repeated cross-validation on separate train and test sets
myControl <- trainControl(method = "LGOCV", 
                              number = 3,
                              p = 0.7,
                              summaryFunction = multiClassSummary, 
                              classProbs = TRUE, 
                              verboseIter = TRUE,
                              seeds = seed.list,  
                              savePredictions = "all",
                              returnResamp = "all",
                              allowParallel = TRUE)

cl <- makePSOCKcluster(detectCores()-1)
registerDoParallel(cl)

test_list <- caretList(blueprint,
                       data = dat,
                       trControl = myControl,
                       methodList = list("glm", "gbm"),
                       metric = "AUC")
stopCluster(cl)

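The caretList documentation says to supply the same inputs as you would to train, but I keep getting the error 'Error in extractCaretTarget.default(...) : argument "y" is missing, with no default', even though the recipe is correct. What am I doing wrong here? P.S. If the LGOCV setup is incorrect, please comment; this is my first time using it.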

Tags: r, machine-learning, r-caret, training-data, recipe

Solution


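I just ran into the same problem. extractCaretTarget is an S3 generic that does not yet have a method for the recipe class, and caretList uses it to create the resamples. If the resamples are not passed in advance, caretList tries to generate them itself and needs the outcome y to build stratified folds: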

if (is.null(trControl$index)) {
  target <- extractCaretTarget(...)
  trControl <- trControlCheck(x = trControl, y = target)
}

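You can work around this by setting the index argument explicitly when you declare trainControl. When index is missing, the indices are generated for each resampling method like this (x is the trainControl object, y is the outcome):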

if (x$method == "boot" | x$method == "adaptive_boot") {
  x$index <- createResample(y, times = x$number, list = TRUE)
} else if (x$method == "cv" | x$method == "adaptive_cv") {
  x$index <- createFolds(y, k = x$number, list = TRUE, returnTrain = TRUE)
} else if (x$method == "repeatedcv") {
  x$index <- createMultiFolds(y, k = x$number, times = x$repeats)
} else if (x$method == "LGOCV" | x$method == "adaptive_LGOCV") {
  x$index <- createDataPartition(
    y,
    times = x$number,
    p = 0.5,
    list = TRUE,
    groups = min(5, length(y)))
}
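
For the LGOCV setup in the question, building the index up front could look roughly like the sketch below. It is a minimal example, not a tested fix for every caretEnsemble version: it uses a plain data.frame subset of iris instead of the data.table from the question, leaves out the seeds and parallel settings, and reuses the question's p = 0.7 for the partitions (an assumption on my part; the internal code above uses p = 0.5).

```r
library(caret)

data(iris)
# Two-class subset, as in the question (plain data.frame for brevity)
dat <- droplevels(iris[iris$Species %in% c("setosa", "versicolor"), ])

# Create the training-row indices ourselves, stratified on the outcome,
# so nothing downstream has to extract y to build the resamples.
set.seed(42)
idx <- createDataPartition(dat$Species, times = 3, p = 0.7, list = TRUE)

myControl <- trainControl(
  method = "LGOCV",
  number = 3,
  p = 0.7,
  index = idx,                          # explicit resampling indices
  summaryFunction = multiClassSummary,
  classProbs = TRUE,
  savePredictions = "all",
  returnResamp = "all"
)
```

Passing this control as trControl = myControl in the caretList call, together with the recipe and data = dat, should mean trControl$index is no longer NULL, so the branch that calls extractCaretTarget is skipped.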
