首页 > 解决方案 > 如何使用 purrr 中的 cross 和 pmap 对多个模型执行 k 折交叉验证?

问题描述

我正在尝试使用函数从头开始构建交叉验证过程tidyverse,当我尝试在迭代、折叠的各种组合中使用cross和迭代我的模型拟合和预测生成函数时,我遇到了困难,pmap和模型。

这是我正在尝试做的一个最小示例。如果我手动应用该功能(例如,),该功能将起作用,并且我尝试过my_function(1, 1, formula_list[[1]])的各种版本都会产生所需长度和(我认为的)结构的列表。cross但是当我尝试申请pmap该列表时,我收到一个关于“未使用的参数”的错误。

library(purrr)
library(dplyr)

df <- data.frame(i = rep(seq(2), each = 50),
                 k = rep(seq(5), times = 20),
                 y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))

formula_list <- list( as.formula(y ~ x1), as.formula(y ~ x1 + x2))

my_function <- function(my_i, my_k, my_formula) {

    train <- filter(df, i == my_i & k != my_k)
    test <- filter(df, i == my_i & k == my_k)

    mod <- lm(my_formula, data = train)

    test$pred <- predict(mod, newdata = test)

    return(test)

}

# this throws an error about unused arguments
crossArg <- cross3(seq(2), seq(5), formula_list)
results <- pmap(crossArg, my_function)

# this throws the same error
crossArg <- cross(list(seq(2), seq(5), formula_list))
results <- pmap(crossArg, my_function)

我一定错过了一些关于 语法的基本点pmap,但是我查看了文档和一些在线示例,但仍然卡住了。

标签: rpurrr

解决方案


这是您正在寻找的 -> 第一枪:

#transforming your crossArg object into a 'simpler' list ie sort of 'flatten' it a bit
xx <- do.call(rbind, crossArgs)
#will give you this
xx
  [,1] [,2] [,3]      
 [1,] 1    1    Expression
 [2,] 2    1    Expression
 [3,] 1    2    Expression
 [4,] 2    2    Expression
 [5,] 1    3    Expression
 [6,] 2    3    Expression
...
#which you can then pmap like this 
results <- pmap(list(xx[ ,1], xx[,2], xx[ ,3]), .f = my_function)
results
[[1]]
   i k        y      x1       x2    pred
1  1 1  1.06302  1.9470 -0.13058 -0.5076
2  1 1 -0.26102  0.2096  0.64801 -0.3544
3  1 1 -1.44488  0.6056  1.13862 -0.3893
4  1 1  1.94536  0.1976 -0.10705 -0.3533
...
[[20]]
   i k       y       x1       x2      pred
1  2 5 -0.1085  0.76503  0.87501 -0.123588
2  2 5 -0.6337 -0.72294 -0.35574  0.256372
3  2 5 -0.1284  0.98152 -0.68990 -0.363973
4  2 5 -1.0502  1.03324  0.05394 -0.302769
5  2 5  1.1303  0.05811 -0.28898 -0.004556
6  2 5  0.2425 -0.56192  0.76655  0.320239
7  2 5 -0.6825  0.97010  0.51890 -0.231752
8  2 5 -0.7992  0.07324 -0.20911 -0.001270
9  2 5 -0.2876  0.87090 -0.48919 -0.304710
10 2 5 -0.1145  1.38314  1.89403 -0.227532

PS:真的很喜欢pmap()那种方式的应用,以及你在那里提供的一个很好的reprex来帮助解决......


推荐阅读