R data.table: what is the fastest way to intersect by keys and groups over multiple columns with data.table?

Problem description

Major edit, to clarify: the answer below is wrong.

I have a data.table with group columns (split_by), key columns (key_by) and trait id columns (intersect_by).

Within each split_by group, I want to keep only the rows whose trait ids are shared by all keys present in that group.

For example:

dt <- data.table(id = 1:6, key1 = 1, key2 = c(1:2, 2), group_id1 = 1, group_id2 = c(1:2, 2:1, 1:2), trait_id1 = 1, trait_id2 = 2:1)
setkey(dt, group_id1, group_id2, trait_id1, trait_id2)
dt
   id key1 key2 group_id1 group_id2 trait_id1 trait_id2
1:  4    1    1         1         1         1         1
2:  1    1    1         1         1         1         2
3:  5    1    2         1         1         1         2
4:  2    1    2         1         2         1         1
5:  6    1    2         1         2         1         1
6:  3    1    2         1         2         1         2

res <- intersect_this_by(dt,
                         key_by = c("key1"),
                         split_by = c("group_id1", "group_id2"),
                         intersect_by = c("trait_id1", "trait_id2"))

I would like res to look like this:

> res[]
   id key1 key2 group_id1 group_id2 trait_id1 trait_id2
1:  1    1    1         1         1         1         2
2:  5    1    2         1         1         1         2
3:  2    1    2         1         2         1         1
4:  6    1    2         1         2         1         1
5:  3    1    2         1         2         1         2

We see that id 4 has been dropped: in the group with group_id1 = 1 and group_id2 = 1 (the group id 4 belongs to), only one key combination, (1, 1), carries the traits (1, 1), whereas the group contains two key combinations: (1, 1) and (1, 2). The traits (1, 1) are therefore not shared by all keys in that group, so this trait is dropped from the group, and id 4 with it. Conversely, ids 1 and 5 have the same traits but different keys, and together they cover all keys of the group ((1, 1) and (1, 2)), so the traits of ids 1 and 5 are kept.
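
To make the counting argument concrete, here is a small check of the (1, 1) group (my own illustration; it assumes the key columns are key1 and key2):

# number of distinct (key1, key2) pairs per trait within group (1, 1)
dt[group_id1 == 1 & group_id2 == 1,
   uniqueN(.SD), by = .(trait_id1, trait_id2), .SDcols = c("key1", "key2")]
#    trait_id1 trait_id2 V1
# 1:         1         1  1   <- covers only key (1, 1): id 4 is dropped
# 2:         1         2  2   <- covers both keys of the group: ids 1 and 5 are kept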

A function that achieves this is given below:

intersect_this_by2 <- function(dt,
                               key_by = NULL,
                               split_by = NULL,
                               intersect_by = NULL){

    dtc <- as.data.table(dt)       

    # compute number of keys in the group
    dtc[, n_keys := uniqueN(.SD), by = split_by, .SDcols = key_by]
    # compute number of keys represented by each trait in each group 
    # and keep row only if they represent all keys from the group
    dtc[, keep := n_keys == uniqueN(.SD), by = c(intersect_by, split_by), .SDcols = key_by]
    dtc <- dtc[keep == TRUE][, c("n_keys", "keep") := NULL]
    return(dtc)      
}
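
As a quick sanity check (my addition, not part of the original post), the function reproduces the expected result on the toy data above:

res2 <- intersect_this_by2(dt,
                           key_by = c("key1", "key2"),
                           split_by = c("group_id1", "group_id2"),
                           intersect_by = c("trait_id1", "trait_id2"))
res2[, id]   # 1 5 2 6 3 -- id 4 is gone, as expected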

But this becomes quite slow for large datasets or for complex traits/keys/groups... The real data.table has 10 million rows and the traits have 30 levels... Is there a way to improve it? Are there any obvious pitfalls? Thanks for your help!

Final edit: Uwe came up with a neat solution which is 40% faster than my initial code (deleted here because it was messy). The final function looks like this:

intersect_this_by_uwe <- function(dt,
                                  key_by = c("key1"),
                                  split_by = c("group_id1", "group_id2"),
                                  intersect_by = c("trait_id1", "trait_id2")){
    dti <- copy(dt)
    dti[, original_order_id__ := 1:.N]
    setkeyv(dti, c(split_by, intersect_by, key_by))
    uni <- unique(dti, by = c(split_by, intersect_by, key_by))
    unique_keys_by_group <-
        unique(uni, by = c(split_by, key_by))[, .N, by = c(split_by)]
    unique_keys_by_group_and_trait <-
        uni[, .N, by = c(split_by, intersect_by)]
    # 1st join to pick group/traits combinations with equal number of unique keys
    selected_groups_and_traits <-
        unique_keys_by_group_and_trait[unique_keys_by_group,
                                       on = c(split_by, "N"), nomatch = 0L]
    # 2nd join to pick records of valid subsets
    dti[selected_groups_and_traits, on = c(split_by, intersect_by)][
        order(original_order_id__), -c("original_order_id__","N")]
}
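
A quick way to convince oneself that the two implementations agree on the toy data (my illustration; check.attributes = FALSE because the two results differ in their key/index attributes):

res_a <- intersect_this_by2(dt, key_by = c("key1", "key2"),
                            split_by = c("group_id1", "group_id2"),
                            intersect_by = c("trait_id1", "trait_id2"))
res_b <- intersect_this_by_uwe(dt, key_by = c("key1", "key2"),
                               split_by = c("group_id1", "group_id2"),
                               intersect_by = c("trait_id1", "trait_id2"))
all.equal(res_a, res_b, check.attributes = FALSE)   # expected: TRUE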

For the record, the benchmark on the 10 M rows dataset:

> microbenchmark::microbenchmark(old_way = {res <- intersect_this_by(dt,
+                                                                    key_by = c("key1"),
+                                                                    split_by = c("group_id1", "group_id2"),
+                                                                    intersect_by = c("trait_id1", "trait_id2"))},
+                                new_way = {res <- intersect_this_by2(dt,
+                                                                     key_by = c("key1"),
+                                                                     split_by = c("group_id1", "group_id2"),
+                                                                     intersect_by = c("trait_id1", "trait_id2"))},
+                                new_way_uwe = {res <- intersect_this_by_uwe(dt,
+                                                                            key_by = c("key1"),
+                                                                            split_by = c("group_id1", "group_id2"),
+                                                                            intersect_by = c("trait_id1", "trait_id2"))},
+                                times = 10)
Unit: seconds
        expr       min        lq      mean    median        uq       max neval cld
     old_way  3.145468  3.530898  3.514020  3.544661  3.577814  3.623707    10  b 
     new_way 15.670487 15.792249 15.948385 15.988003 16.097436 16.206044    10   c
 new_way_uwe  1.982503  2.350001  2.320591  2.394206  2.412751  2.436381    10 a  

Tags: r, merge, data.table

Solution


Edit

Although the answer below does reproduce the expected result for the small sample dataset, it fails to give the correct answer for the large, 10 M rows dataset provided by the OP.

However, I have decided to keep this wrong answer because the benchmark results show the poor performance of the uniqueN() function. In addition, the answer contains benchmarks of much faster, alternative solutions.



If I understand correctly, the OP wants to keep only those unique combinations of group_id1, group_id2, trait_id1, and trait_id2 which appear with more than one distinct key1.

This can be achieved by counting the unique values of key1 in each group of group_id1, group_id2, trait_id1, and trait_id2, and selecting only those combinations of group_id1, group_id2, trait_id1, and trait_id2 where the count is greater than one. Finally, the matching rows are retrieved by a join:

library(data.table)
sel <- dt[, uniqueN(key1), by = .(group_id1, group_id2, trait_id1, trait_id2)][V1 > 1]
sel
   group_id1 group_id2 trait_id1 trait_id2 V1
1:         1         2         3         1  2
2:         2         2         2         1  2
3:         2         1         1         2  2
4:         1         1         1         1  2
5:         1         1         2         2  2
6:         2         2         2         2  2
7:         1         1         1         2  2
8:         1         1         3         2  2
res <- dt[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][order(id), -"V1"]
res
    id key1 group_id1 trait_id1 group_id2 trait_id2 extra
 1:  1    2         1         3         2         1     u
 2:  2    1         2         2         2         1     g
 3:  5    2         2         1         1         2     g
 4:  8    2         1         3         2         1     o
 5:  9    2         1         1         1         1     d
 6: 10    2         2         1         1         2     g
 7: 13    1         2         1         1         2     c
 8: 14    2         1         2         1         2     t
 9: 15    1         1         3         2         1     y
10: 16    2         1         3         2         1     v
11: 19    2         2         2         2         2     y
12: 22    2         2         2         2         1     g
13: 24    2         1         1         1         2     i
14: 25    1         1         3         1         2     n
15: 26    1         2         2         2         2     y
16: 27    1         1         1         1         1     n
17: 28    1         1         1         1         2     h
18: 29    1         2         2         2         2     b
19: 30    2         1         3         1         2     k
20: 31    1         2         2         2         2     w
21: 35    1         1         2         1         2     q
22: 37    2         2         1         1         2     r
23: 39    1         1         1         1         2     o
    id key1 group_id1 trait_id1 group_id2 trait_id2 extra

This reproduces the OP's expected result, but is it also the fastest way, as the OP has requested?


Benchmarking, part 1

The OP's code to create the benchmark data is used here (but with 1 M rows instead of 10 M):

set.seed(0)
n <- 1e6
p <- 1e5
m <- 5
dt <- data.table(id = 1:n,
                 key1 = sample(1:m, size = n, replace = TRUE),
                 group_id1 = sample(1:2, size = n, replace = TRUE),
                 trait_id1 = sample(1:p, size = n, replace = TRUE),
                 group_id2 = sample(1:2, size = n, replace = TRUE),
                 trait_id2 = sample(1:2, size = n, replace = TRUE),
                 extra = sample(letters, n, replace = TRUE))

I was surprised to find that the solution using uniqueN() is not the fastest one:

Unit: milliseconds
    expr       min        lq      mean    median        uq       max neval cld
 old_way  489.4606  496.3801  523.3361  503.2997  540.2739  577.2482     3 a  
 new_way 9356.4131 9444.5698 9567.4035 9532.7265 9672.8987 9813.0710     3   c
    uwe1 5946.4533 5996.7388 6016.8266 6047.0243 6052.0133 6057.0023     3  b

Benchmark code:

microbenchmark::microbenchmark(
  old_way = {
    DT <- copy(dt)
    res <- intersect_this_by(DT,
                             key_by = c("key1"),
                             split_by = c("group_id1", "group_id2"),
                             intersect_by = c("trait_id1", "trait_id2"))
  },
  new_way = {
    DT <- copy(dt)
    res <- intersect_this_by2(DT,
                              key_by = c("key1"),
                              split_by = c("group_id1", "group_id2"),
                              intersect_by = c("trait_id1", "trait_id2"))
  },
  uwe1 = {
    DT <- copy(dt)
    sel <- DT[, uniqueN(key1), by = .(group_id1, group_id2, trait_id1, trait_id2)][V1 > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  times = 3L)

Note that a fresh copy of the benchmark data is used for each run in order to avoid any side effects from previous runs, e.g., indices set by data.table.
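
To illustrate the kind of side effect meant here (my sketch, not part of the original answer): data.table silently creates a secondary index the first time a column is filtered with ==, and such leftovers could distort subsequent timings:

DT <- copy(dt)
DT[group_id1 == 1]      # first == filter on a column auto-creates an index
indices(DT)             # "group_id1" -- the leftover index
setindex(DT, NULL)      # drop all indices to restore a clean state
# or disable automatic indexing for the session:
# options(datatable.auto.index = FALSE)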

Switching on verbose mode with

options(datatable.verbose = TRUE)

reveals that most of the time is spent in the computation of uniqueN() for all the groups:

sel <- DT[, uniqueN(key1), by = .(group_id1, group_id2, trait_id1, trait_id2)][V1 > 1]

Detected that j uses these columns: key1 
Finding groups using forderv ... 0.060sec 
Finding group sizes from the positions (can be avoided to save RAM) ... 0.000sec 
Getting back original order ... 0.050sec 
lapply optimization is on, j unchanged as 'uniqueN(key1)'
GForce is on, left j unchanged
Old mean optimization is on, left j unchanged.
Making each group and running j (GForce FALSE) ... 
  collecting discontiguous groups took 0.084s for 570942 groups
  eval(j) took 5.505s for 570942 calls
5.940sec

This is a known issue. However, the alternative length(unique()) (of which uniqueN() is an abbreviation) brings only a modest speedup by a factor of 2.
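
For illustration, the length(unique()) variant is a drop-in replacement for the uniqueN() call above (a sketch, using the same DT and grouping columns as in the benchmark):

sel <- DT[, length(unique(key1)),
          by = .(group_id1, group_id2, trait_id1, trait_id2)][V1 > 1]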

So I started looking for ways to avoid uniqueN() and length(unique()) altogether.


Benchmarking, part 2

I found two alternatives which are sufficiently fast. Both create, in a first step, a data.table of the unique combinations of group_id1, group_id2, trait_id1, trait_id2, and key1, then count the number of distinct key1 values for each group of group_id1, group_id2, trait_id1, and trait_id2, and filter for counts greater than one:

sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
  , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]

sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
  , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]

The verbose output shows that the computing times of these variants are considerably better.

For the benchmark, only the fastest methods are used, now with 10 M rows. In addition, each variant is tried with setkey() and setorder(), respectively, applied beforehand:

microbenchmark::microbenchmark(
  old_way = {
    DT <- copy(dt)
    res <- intersect_this_by(DT,
                             key_by = c("key1"),
                             split_by = c("group_id1", "group_id2"),
                             intersect_by = c("trait_id1", "trait_id2"))
  },
  uwe3 = {
    DT <- copy(dt)
    sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe3k = {
    DT <- copy(dt)
    setkey(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe3o = {
    DT <- copy(dt)
    setorder(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe4 = {
    DT <- copy(dt)
    sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe4k = {
    DT <- copy(dt)
    setkey(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe4o = {
    DT <- copy(dt)
    setorder(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  times = 3L)

The benchmark results for the 10 M case show that both variants are faster than the OP's intersect_this_by() function, and that keying or ordering beforehand drives the speedup (with a minuscule advantage for ordering).

Unit: seconds
    expr      min       lq     mean   median       uq      max neval  cld
 old_way 7.173517 7.198064 7.256211 7.222612 7.297559 7.372506     3    d
    uwe3 6.820324 6.833151 6.878777 6.845978 6.908003 6.970029     3   c 
   uwe3k 5.349949 5.412018 5.436806 5.474086 5.480234 5.486381     3 a   
   uwe3o 5.423440 5.432562 5.467376 5.441683 5.489344 5.537006     3 a   
    uwe4 6.270724 6.276757 6.301774 6.282790 6.317299 6.351807     3  b  
   uwe4k 5.280763 5.295251 5.418803 5.309739 5.487823 5.665906     3 a   
   uwe4o 4.921627 5.095762 5.157010 5.269898 5.274702 5.279506     3 a
