首页 > 解决方案 > 指定要在留一法(jack-knife)交叉验证中使用的选定数据范围,以便在 caret::train 函数中使用

问题描述

这个问题建立在我在这里提出的问题之上:Creating data partitions over a selected range of data to be feed into caret::train function for cross-validation)。

我正在使用的数据如下所示:

df <- data.frame(Effect = rep(seq(from = 0.05, to = 1, by = 0.05), each = 5), Time = rep(c(1:20,1:20), each = 5), Replicate = c(1:5))

本质上,我想做的是创建自定义分区,例如由caret::groupKFold函数生成的分区,但这些折叠要超过指定范围(即 > 15 天),并且每次折叠保留一个点作为测试集和以及用于训练的所有其他数据。这将在每次迭代中重复,直到指定范围内的每个点都被用作测试集。@Missuse 为此编写了一些代码,该代码接近上述链接中该问题的所需输出。

我会尝试向您展示所需的输出,但老实说 caret::groupKFold 函数的输出让我感到困惑,所以希望上面的描述就足够了。很高兴尝试澄清!

标签: rcross-validationr-caretdata-partitioning

解决方案


这是您可以使用以下方法创建所需分区的一种方法tidyverse

library(tidyverse)

df %>%
  mutate(id = row_number()) %>% #create a column called id which will hold the row numbers
  filter(Time > 15) %>% #subset data frame according to your description 
  split(.$id)  %>% #split the data frame into lists by id (row number)
  map(~ .x %>% select(id) %>% #clean up so it works with indexOut argument in trainControl
        unlist %>%
        unname) -> folds_cv

编辑:似乎indexOut参数没有按预期执行,但是index在使一个参数这样做之后,folds_cv可以使用以下方法得到相反的结果setdiff

folds_cv <- lapply(folds_cv, function(x) setdiff(1:nrow(df), x))

现在:

test_control <- trainControl(index = folds_cv,
                             savePredictions = "final")


quad.lm2 <- train(Time ~ Effect,
                  data = df,
                  method = "lm",
                  trControl = test_control)

带有警告:

Warning message:
In nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,  :
  There were missing values in resampled performance measures.
> quad.lm2
Linear Regression 

200 samples
  1 predictor

No pre-processing
Resampling: Bootstrapped (50 reps) 
Summary of sample sizes: 199, 199, 199, 199, 199, 199, ... 
Resampling results:

  RMSE          Rsquared  MAE         
  3.552714e-16  NaN       3.552714e-16

Tuning parameter 'intercept' was held constant at a value of TRUE

所以每个重新采样使用 199 行并在 1 上进行预测,对我们想要一次保留的所有 50 行重复。这可以通过以下方式验证:

quad.lm2$pred

为什么Rsquared失踪我不确定我会更深入地挖掘。


推荐阅读