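r - How do I run a grid search over multiple groups in my data?
Problem description
Data and goal
My dataset contains multiple groups. Here is a simple reproducible example: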
library(tidyverse)
df <- tibble(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)),
  y = (0.5 * x) + (0.7 * x)
)
Here I have 2 participants (A and B), and the x and y variables are recorded for both of them.
I want to estimate y_hat with a custom function, as follows:
find_y_hat <- function(participant_data, param1, param2){
  participant_data %>%
    mutate(y_hat = (param1 * x) + (param2 * x))
}
Example:
If I supply values for param1 and param2, I get the following result for participant A:
participant_A_data <- df %>%
  filter(participant == "A")
find_y_hat(participant_A_data,
           param1 = 0, param2 = 0.5)
# A tibble: 5 x 5
participant Time x y y_hat
<chr> <int> <dbl> <dbl> <dbl>
1 A 1 -0.336 -0.404 -0.168
2 A 2 1.24 1.49 0.619
3 A 3 0.520 0.624 0.260
4 A 4 -0.438 -0.525 -0.219
5 A 5 0.122 0.147 0.0612
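My goal is to apply this function to each participant with several different values of the two parameters. So I have the following parameter grid: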
paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
)
paramz_grid
param1 param2
1 0.0 0.0
2 0.5 0.0
3 0.6 0.0
4 0.7 0.0
5 0.0 0.5
6 0.5 0.5
7 0.6 0.5
8 0.7 0.5
9 0.0 0.6
10 0.5 0.6
11 0.6 0.6
12 0.7 0.6
13 0.0 0.7
14 0.5 0.7
15 0.6 0.7
16 0.7 0.7
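What I tried
I think I can use nesting and the purrr::map functions to achieve this. But I don't know how to combine the input for each participant (total = 2) with each parameter combination (total = 16). I think pmap could be used, but I don't know how to use it. Please guide me.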
df %>%
  group_by(participant) %>%
  group_nest()
# A tibble: 2 x 2
participant data
* <chr> <list<tibble>>
1 A [5 x 3]
2 B [7 x 3]
Edit
I tried the following, but it did not work:
paramz_grid <- paramz_grid %>%
  rowid_to_column()
l1 <- paramz_grid %>%
  split(.$rowid)
l2 <- df %>%
  split(.$participant)
pmap(.l = list(l1, l2, rowz = 1:nrow(paramz_grid)),
     .f = ~find_y_hat(participant_data = l2,
                      l1[[rowz]]$param1, l1[[rowz]]$param2))
Error: Element 2 of `.l` must have length 1 or 16, not 2
Run `rlang::last_error()` to see where the error occurred.
Solution
I know you are probably looking for a tidyverse solution, but as an alternative I can offer this data.table approach:
library(data.table)
set.seed(123)
df <- data.table(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)))[
  , y := (0.5 * x) + (0.7 * x)]
paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
)
df[, paste0("y_hat: ", apply(paramz_grid, 1, paste, collapse = ",")) :=
     lapply(seq_len(nrow(paramz_grid)),
            function(z) (paramz_grid[z, 1] * x) + (paramz_grid[z, 2] * x)),
   by = participant][]
#> participant Time x y y_hat: 0,0 y_hat: 0.5,0
#> 1: A 1 -0.56047565 -0.67257078 0 -0.28023782
#> 2: A 2 -0.23017749 -0.27621299 0 -0.11508874
#> 3: A 3 1.55870831 1.87044998 0 0.77935416
#> 4: A 4 0.07050839 0.08461007 0 0.03525420
#> 5: A 5 0.12928774 0.15514528 0 0.06464387
#> 6: B 1 17.15064987 20.58077984 0 8.57532493
#> 7: B 2 4.60916206 5.53099447 0 2.30458103
#> 8: B 3 -12.65061235 -15.18073482 0 -6.32530617
#> 9: B 4 -6.86852852 -8.24223422 0 -3.43426426
#> 10: B 5 -4.45661970 -5.34794364 0 -2.22830985
#> 11: B 6 12.24081797 14.68898157 0 6.12040899
#> 12: B 7 3.59813827 4.31776592 0 1.79906914
#> y_hat: 0.6,0 y_hat: 0.7,0 y_hat: 0,0.5 y_hat: 0.5,0.5 y_hat: 0.6,0.5
#> 1: -0.33628539 -0.39233295 -0.28023782 -0.56047565 -0.61652321
#> 2: -0.13810649 -0.16112424 -0.11508874 -0.23017749 -0.25319524
#> 3: 0.93522499 1.09109582 0.77935416 1.55870831 1.71457915
#> 4: 0.04230503 0.04935587 0.03525420 0.07050839 0.07755923
#> 5: 0.07757264 0.09050141 0.06464387 0.12928774 0.14221651
#> 6: 10.29038992 12.00545491 8.57532493 17.15064987 18.86571486
#> 7: 2.76549724 3.22641344 2.30458103 4.60916206 5.07007827
#> 8: -7.59036741 -8.85542864 -6.32530617 -12.65061235 -13.91567358
#> 9: -4.12111711 -4.80796996 -3.43426426 -6.86852852 -7.55538137
#> 10: -2.67397182 -3.11963379 -2.22830985 -4.45661970 -4.90228167
#> 11: 7.34449078 8.56857258 6.12040899 12.24081797 13.46489977
#> 12: 2.15888296 2.51869679 1.79906914 3.59813827 3.95795210
#> y_hat: 0.7,0.5 y_hat: 0,0.6 y_hat: 0.5,0.6 y_hat: 0.6,0.6 y_hat: 0.7,0.6
#> 1: -0.67257078 -0.33628539 -0.61652321 -0.67257078 -0.72861834
#> 2: -0.27621299 -0.13810649 -0.25319524 -0.27621299 -0.29923074
#> 3: 1.87044998 0.93522499 1.71457915 1.87044998 2.02632081
#> 4: 0.08461007 0.04230503 0.07755923 0.08461007 0.09166091
#> 5: 0.15514528 0.07757264 0.14221651 0.15514528 0.16807406
#> 6: 20.58077984 10.29038992 18.86571486 20.58077984 22.29584483
#> 7: 5.53099447 2.76549724 5.07007827 5.53099447 5.99191068
#> 8: -15.18073482 -7.59036741 -13.91567358 -15.18073482 -16.44579605
#> 9: -8.24223422 -4.12111711 -7.55538137 -8.24223422 -8.92908707
#> 10: -5.34794364 -2.67397182 -4.90228167 -5.34794364 -5.79360561
#> 11: 14.68898157 7.34449078 13.46489977 14.68898157 15.91306337
#> 12: 4.31776592 2.15888296 3.95795210 4.31776592 4.67757975
#> y_hat: 0,0.7 y_hat: 0.5,0.7 y_hat: 0.6,0.7 y_hat: 0.7,0.7
#> 1: -0.39233295 -0.67257078 -0.72861834 -0.78466591
#> 2: -0.16112424 -0.27621299 -0.29923074 -0.32224849
#> 3: 1.09109582 1.87044998 2.02632081 2.18219164
#> 4: 0.04935587 0.08461007 0.09166091 0.09871175
#> 5: 0.09050141 0.15514528 0.16807406 0.18100283
#> 6: 12.00545491 20.58077984 22.29584483 24.01090982
#> 7: 3.22641344 5.53099447 5.99191068 6.45282688
#> 8: -8.85542864 -15.18073482 -16.44579605 -17.71085728
#> 9: -4.80796996 -8.24223422 -8.92908707 -9.61593993
#> 10: -3.11963379 -5.34794364 -5.79360561 -6.23926758
#> 11: 8.56857258 14.68898157 15.91306337 17.13714516
#> 12: 2.51869679 4.31776592 4.67757975 5.03739358
Created on 2021-03-14 by the reprex package (v1.0.0)
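The wide table above adds one column per grid row, which gets unwieldy as the grid grows. The following is a minimal sketch of a long layout, under the assumption (not stated in the question) that the eventual aim is to pick the best-fitting combination per participant. The names `long`, `errors`, `best`, and the sum-of-squared-errors criterion are illustrative choices, not part of the original answer.

```r
library(data.table)

set.seed(123)
df <- data.table(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10))
)[, y := (0.5 * x) + (0.7 * x)]

paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
)

# For each grid row, copy the data, attach that row's parameters and the
# resulting y_hat, then stack: 16 grid rows x 12 data rows = 192 rows.
long <- rbindlist(lapply(seq_len(nrow(paramz_grid)), function(z) {
  out <- copy(df)
  out[, `:=`(param1 = paramz_grid$param1[z],
             param2 = paramz_grid$param2[z])]
  out[, y_hat := (param1 * x) + (param2 * x)]
  out
}))

# Assumed criterion: sum of squared errors per participant and combination.
errors <- long[, .(sse = sum((y - y_hat)^2)),
               by = .(participant, param1, param2)]

# Keep the first combination with the smallest error for each participant.
best <- errors[, .SD[which.min(sse)], by = participant]
best
```

Because `param1` and `param2` both multiply `x` in this toy example, only their sum matters, so several combinations fit equally well and `which.min()` simply returns the first of them.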
If you prefer to keep y_hat as a list column, you can modify the operation above:
df[, y_hat := list(.(lapply(seq_len(nrow(paramz_grid)),
                            function(z) (paramz_grid[z, 1] * x) + (paramz_grid[z, 2] * x)))),
   by = participant][]
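For readers who do want to stay in the tidyverse, here is one possible sketch. The `pmap()` attempt in the question fails because `pmap()` walks its inputs in parallel, so every element of `.l` must have length 1 or the same common length; a list of 2 participants cannot be paired with 16 grid rows. One way around that is to build all 2 x 16 = 32 pairings first and then iterate over those rows. The names `by_participant`, `combos`, `grid_row`, `fit`, and `long` are illustrative, not from the original post.

```r
library(dplyr)
library(tidyr)
library(purrr)
library(tibble)

set.seed(123)
df <- tibble(
  participant = c(rep("A", 5), rep("B", 7)),
  Time = c(1:5, 1:7),
  x = c(rnorm(5, 0, 1), rnorm(7, 0, 10)),
  y = (0.5 * x) + (0.7 * x)
)

find_y_hat <- function(participant_data, param1, param2) {
  participant_data %>%
    mutate(y_hat = (param1 * x) + (param2 * x))
}

paramz_grid <- expand.grid(
  param1 = c(0, 0.5, 0.6, 0.7),
  param2 = c(0, 0.5, 0.6, 0.7)
)

# Named list with one tibble per participant.
by_participant <- split(df, df$participant)

# Every participant paired with every grid row: 2 x 16 = 32 rows.
combos <- expand_grid(
  participant = names(by_participant),
  grid_row = seq_len(nrow(paramz_grid))
) %>%
  mutate(param1 = paramz_grid$param1[grid_row],
         param2 = paramz_grid$param2[grid_row])

# All three inputs now have length 32, so pmap() can walk them in parallel.
results <- combos %>%
  mutate(fit = pmap(
    list(participant, param1, param2),
    function(p, p1, p2) find_y_hat(by_participant[[p]], p1, p2)
  ))

# Optional: flatten the list column into one long table.
# `participant` is dropped first because each fitted tibble already has it.
long <- results %>%
  select(grid_row, param1, param2, fit) %>%
  unnest(fit)
```

`results` keeps one fitted tibble per participant and parameter combination in the `fit` list column, while `long` stacks them for plotting or summarising.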