How to do a grid search for multiple groups in the data?

Problem description

Data and goal

My dataset has multiple groups. Here is a simple reproducible example:

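library(tidyverse)

# No seed is set, so the random x values (and the output below) differ between runs
df <- tibble(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)),
  y = (0.5 * x) + (0.7 * x)  # tibble() can refer to x, created just above
)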

Here I have 2 participants (A and B), and the variables x and y are recorded for both of them.

I want to estimate y_hat with a custom function, as follows:

find_y_hat <- function(participant_data, param1, param2){
  # Adds a y_hat column computed from x and the two parameters
  participant_data %>% 
    mutate(y_hat = (param1 * x) + (param2 * x))
}

Example:

If I supply values for param1 and param2, I get the following result for participant A:

participant_A_data <- df %>% 
  filter(participant == "A")

find_y_hat(participant_A_data, 
           param1 = 0, param2 = 0.5)
# A tibble: 5 x 5
  participant  Time      x      y   y_hat
  <chr>       <int>  <dbl>  <dbl>   <dbl>
1 A               1 -0.336 -0.404 -0.168 
2 A               2  1.24   1.49   0.619 
3 A               3  0.520  0.624  0.260 
4 A               4 -0.438 -0.525 -0.219 
5 A               5  0.122  0.147  0.0612

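My goal is to apply this function to each participant using several different values of the two parameters. So I have the following parameter grid: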

paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
)
paramz_grid
param1 param2
1     0.0    0.0
2     0.5    0.0
3     0.6    0.0
4     0.7    0.0
5     0.0    0.5
6     0.5    0.5
7     0.6    0.5
8     0.7    0.5
9     0.0    0.6
10    0.5    0.6
11    0.6    0.6
12    0.7    0.6
13    0.0    0.7
14    0.5    0.7
15    0.6    0.7
16    0.7    0.7

What I tried

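I thought I could achieve this with nesting and the purrr::map functions, but I don't know how to combine each participant's input (2 in total) with each parameter combination (16 in total). I think pmap could be used, but I don't know how to apply it here. Please advise.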

df %>% 
   group_by(participant) %>% 
   group_nest()
# A tibble: 2 x 2
  participant           data
* <chr>       <list<tibble>>
1 A                  [5 x 3]
2 B                  [7 x 3]
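One way to finish the nest-then-pmap idea is sketched below. It is a minimal sketch, not the only approach: it assumes tidyr 1.0 or later for expand_grid(), and it adds set.seed(123) (not in the original question) only so the example is reproducible. expand_grid() pairs each nested participant row with each of the 16 parameter rows (2 x 16 = 32 rows), and pmap() then calls find_y_hat() once per row.

```r
library(tidyverse)

set.seed(123)  # added for reproducibility; not part of the original question
df <- tibble(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)),
  y = (0.5 * x) + (0.7 * x)
)

find_y_hat <- function(participant_data, param1, param2) {
  participant_data %>%
    mutate(y_hat = (param1 * x) + (param2 * x))
}

paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
)

# Every nested participant paired with every parameter row: 2 x 16 = 32 rows.
# Unnamed data frames passed to expand_grid() are spliced into flat columns.
results <- expand_grid(df %>% group_nest(participant), paramz_grid) %>%
  # pmap() walks data, param1 and param2 in parallel, one call per row
  mutate(data = pmap(list(data, param1, param2), find_y_hat))

# Back to one row per observation and parameter pair
results_long <- results %>% unnest(data)
```

unnest(data) turns the 32 nested results into one long table with 192 rows (12 observations x 16 parameter pairs), with param1 and param2 kept as columns so each y_hat can be traced back to the parameters that produced it.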

Edit

I tried the following, but it did not work:

paramz_grid <- paramz_grid %>% 
  rowid_to_column()

l1 <- paramz_grid %>% 
  split(.$rowid)

l2 <- df %>% 
  split(.$participant)

pmap(.l = list(l1, l2, rowz = 1:nrow(paramz_grid)),
     .f = ~find_y_hat(participant_data = l2,
                     l1[[rowz]]$param1,  l1[[rowz]]$param2))

Error: Element 2 of `.l` must have length 1 or 16, not 2
Run `rlang::last_error()` to see where the error occurred.
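The error arises because pmap() walks over the elements of .l in parallel, so every element must have the same length (or length 1); here l1 has 16 elements and l2 has 2. A minimal sketch of one fix, keeping the l1 and l2 lists from the attempt, is to enumerate all 32 (parameter row, participant) pairs first and then pmap() over that table. The names combos, rowz and part are illustrative, and set.seed(123) is added only for reproducibility.

```r
library(tidyverse)

set.seed(123)  # added for reproducibility; not part of the original question
df <- tibble(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)),
  y = (0.5 * x) + (0.7 * x)
)

find_y_hat <- function(participant_data, param1, param2) {
  participant_data %>%
    mutate(y_hat = (param1 * x) + (param2 * x))
}

paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
) %>%
  rowid_to_column()

l1 <- split(paramz_grid, paramz_grid$rowid)  # 16 one-row data frames
l2 <- split(df, df$participant)              # 2 data frames: "A" and "B"

# All (parameter row, participant) pairs: 16 x 2 = 32 rows of equal length,
# which is what pmap() needs. rowz varies fastest, so rows 1-16 are "A".
combos <- expand.grid(rowz = seq_along(l1), part = names(l2),
                      stringsAsFactors = FALSE)

# A data frame is a list of columns, so pmap() matches rowz and part by name
out <- pmap(combos, function(rowz, part) {
  find_y_hat(l2[[part]], l1[[rowz]]$param1, l1[[rowz]]$param2)
})
```

out is a list of 32 tibbles in the same order as the rows of combos, so combos tells you which participant and parameter row produced each element.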

Tags: r, tidyverse

Solution


I know you may be looking for a tidyverse solution, but as an alternative I can offer this data.table approach:

library(data.table)
set.seed(123)
df <- data.table(
    participant = c(rep("A", 5), rep("B", 7)),
    Time = c(1:5, 1:7),
    x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)))[
        , y := (0.5 * x) + (0.7 * x)]

paramz_grid <- expand.grid(
    param1 = c(0, 0.5, 0.6, 0.7),
    param2 = c(0, 0.5, 0.6, 0.7)
)

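# One new column per row of paramz_grid (named e.g. "y_hat: 0.5,0"), computed within each participant
df[, paste0("y_hat: ", apply(paramz_grid,1,paste, collapse=",")):=
       lapply(seq_len(nrow(paramz_grid)), 
              function(z) (paramz_grid[z,1] * x) + (paramz_grid[z,2] * x)), 
   by=participant][]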
#>     participant Time            x            y y_hat: 0,0 y_hat: 0.5,0
#>  1:           A    1  -0.56047565  -0.67257078          0  -0.28023782
#>  2:           A    2  -0.23017749  -0.27621299          0  -0.11508874
#>  3:           A    3   1.55870831   1.87044998          0   0.77935416
#>  4:           A    4   0.07050839   0.08461007          0   0.03525420
#>  5:           A    5   0.12928774   0.15514528          0   0.06464387
#>  6:           B    1  17.15064987  20.58077984          0   8.57532493
#>  7:           B    2   4.60916206   5.53099447          0   2.30458103
#>  8:           B    3 -12.65061235 -15.18073482          0  -6.32530617
#>  9:           B    4  -6.86852852  -8.24223422          0  -3.43426426
#> 10:           B    5  -4.45661970  -5.34794364          0  -2.22830985
#> 11:           B    6  12.24081797  14.68898157          0   6.12040899
#> 12:           B    7   3.59813827   4.31776592          0   1.79906914
#>     y_hat: 0.6,0 y_hat: 0.7,0 y_hat: 0,0.5 y_hat: 0.5,0.5 y_hat: 0.6,0.5
#>  1:  -0.33628539  -0.39233295  -0.28023782    -0.56047565    -0.61652321
#>  2:  -0.13810649  -0.16112424  -0.11508874    -0.23017749    -0.25319524
#>  3:   0.93522499   1.09109582   0.77935416     1.55870831     1.71457915
#>  4:   0.04230503   0.04935587   0.03525420     0.07050839     0.07755923
#>  5:   0.07757264   0.09050141   0.06464387     0.12928774     0.14221651
#>  6:  10.29038992  12.00545491   8.57532493    17.15064987    18.86571486
#>  7:   2.76549724   3.22641344   2.30458103     4.60916206     5.07007827
#>  8:  -7.59036741  -8.85542864  -6.32530617   -12.65061235   -13.91567358
#>  9:  -4.12111711  -4.80796996  -3.43426426    -6.86852852    -7.55538137
#> 10:  -2.67397182  -3.11963379  -2.22830985    -4.45661970    -4.90228167
#> 11:   7.34449078   8.56857258   6.12040899    12.24081797    13.46489977
#> 12:   2.15888296   2.51869679   1.79906914     3.59813827     3.95795210
#>     y_hat: 0.7,0.5 y_hat: 0,0.6 y_hat: 0.5,0.6 y_hat: 0.6,0.6 y_hat: 0.7,0.6
#>  1:    -0.67257078  -0.33628539    -0.61652321    -0.67257078    -0.72861834
#>  2:    -0.27621299  -0.13810649    -0.25319524    -0.27621299    -0.29923074
#>  3:     1.87044998   0.93522499     1.71457915     1.87044998     2.02632081
#>  4:     0.08461007   0.04230503     0.07755923     0.08461007     0.09166091
#>  5:     0.15514528   0.07757264     0.14221651     0.15514528     0.16807406
#>  6:    20.58077984  10.29038992    18.86571486    20.58077984    22.29584483
#>  7:     5.53099447   2.76549724     5.07007827     5.53099447     5.99191068
#>  8:   -15.18073482  -7.59036741   -13.91567358   -15.18073482   -16.44579605
#>  9:    -8.24223422  -4.12111711    -7.55538137    -8.24223422    -8.92908707
#> 10:    -5.34794364  -2.67397182    -4.90228167    -5.34794364    -5.79360561
#> 11:    14.68898157   7.34449078    13.46489977    14.68898157    15.91306337
#> 12:     4.31776592   2.15888296     3.95795210     4.31776592     4.67757975
#>     y_hat: 0,0.7 y_hat: 0.5,0.7 y_hat: 0.6,0.7 y_hat: 0.7,0.7
#>  1:  -0.39233295    -0.67257078    -0.72861834    -0.78466591
#>  2:  -0.16112424    -0.27621299    -0.29923074    -0.32224849
#>  3:   1.09109582     1.87044998     2.02632081     2.18219164
#>  4:   0.04935587     0.08461007     0.09166091     0.09871175
#>  5:   0.09050141     0.15514528     0.16807406     0.18100283
#>  6:  12.00545491    20.58077984    22.29584483    24.01090982
#>  7:   3.22641344     5.53099447     5.99191068     6.45282688
#>  8:  -8.85542864   -15.18073482   -16.44579605   -17.71085728
#>  9:  -4.80796996    -8.24223422    -8.92908707    -9.61593993
#> 10:  -3.11963379    -5.34794364    -5.79360561    -6.23926758
#> 11:   8.56857258    14.68898157    15.91306337    17.13714516
#> 12:   2.51869679     4.31776592     4.67757975     5.03739358

Created on 2021-03-14 by the reprex package (v1.0.0)
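If a long table is more convenient than 16 separate y_hat columns (for example, to compare each y_hat against y), the wide result can be reshaped with data.table's melt(). This is a minimal sketch that assumes the "y_hat: param1,param2" column names produced by the code above; the names long and params are illustrative.

```r
library(data.table)

set.seed(123)
df <- data.table(
    participant = c(rep("A", 5), rep("B", 7)),
    Time = c(1:5, 1:7),
    x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)))[
        , y := (0.5 * x) + (0.7 * x)]

paramz_grid <- expand.grid(
    param1 = c(0, 0.5, 0.6, 0.7),
    param2 = c(0, 0.5, 0.6, 0.7)
)

# Wide result, as in the answer: one "y_hat: p1,p2" column per parameter pair
df[, paste0("y_hat: ", apply(paramz_grid, 1, paste, collapse = ",")) :=
       lapply(seq_len(nrow(paramz_grid)),
              function(z) (paramz_grid[z, 1] * x) + (paramz_grid[z, 2] * x)),
   by = participant]

# Stack the 16 y_hat columns into one column: 12 rows x 16 pairs = 192 rows
long <- melt(df,
             id.vars = c("participant", "Time", "x", "y"),
             measure.vars = grep("^y_hat: ", names(df), value = TRUE),
             variable.name = "params", value.name = "y_hat",
             variable.factor = FALSE)

# Recover numeric param1/param2 from names such as "y_hat: 0.5,0"
long[, c("param1", "param2") :=
         tstrsplit(sub("y_hat: ", "", params, fixed = TRUE), ",",
                   fixed = TRUE, type.convert = TRUE)]
```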

If you prefer to keep y_hat as a list column, you can modify the operation above:

# Stores the 16 y_hat vectors in a single list column instead of 16 separate columns
df[, y_hat:=list(.(lapply(seq_len(nrow(paramz_grid)), 
              function(z) (paramz_grid[z,1] * x) + (paramz_grid[z,2] * x)))), 
   by=participant][]
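Because find_y_hat() only uses each row's own x, splitting by participant is not strictly required for this particular function: crossing the raw rows with the grid gives every participant and parameter combination in one step. A minimal tidyverse sketch, assuming tidyr 1.0 or later for expand_grid(); set.seed(123) is added only for reproducibility.

```r
library(tidyverse)

set.seed(123)  # added for reproducibility; not part of the original question
df <- tibble(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)),
  y = (0.5 * x) + (0.7 * x)
)

paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
)

# Pair every observation with every parameter row: 12 x 16 = 192 rows
res <- expand_grid(df, paramz_grid) %>%
  # Same formula as find_y_hat(), evaluated row by row
  mutate(y_hat = (param1 * x) + (param2 * x)) %>%
  arrange(participant, param1, param2, Time)
```

If the real model depends on a participant's whole time series (lags, cumulative sums, per-participant fitting), the nested per-participant approach is the safer choice, since a plain row-wise cross join would not respect those group boundaries.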