首页 > 解决方案 > group_by operation in dplyr vs data.table for fast implementation

问题描述

# Reproducible sample data: 10000 "years" of 12 monthly observations,
# six random measurement columns (x1..x6), and constant month-window
# boundaries (p/m/h start & end) used by the aggregations below.
# FIX: the original did not set a seed, although the answer below
# explicitly relies on set.seed(42) to reproduce its console output.
set.seed(42)
dat <- data.frame(yearID = rep(1:10000, each = 12),
                  monthID = rep(1:12, times = 10000),
                  x1 = rnorm(120000),
                  x2 = rnorm(120000),
                  x3 = rnorm(120000),
                  x4 = rnorm(120000),
                  x5 = rnorm(120000),
                  x6 = rnorm(120000),
                  p.start = 6,   # window boundaries, constant for all rows
                  p.end = 7,
                  m.start = 8,
                  m.end = 9,
                  h.start = 10,
                  h.end = 11)

I need to do some operations on the above data which is described below after my current solution

library(dplyr)

start_time <- Sys.time()

# df1: for indices x1..x4 (x5/x6 excluded), compute five per-year window
# aggregates of `value` (sums over the p, m and h month windows, plus the
# single months p.start and h.end), then reshape so every index/aggregate
# pair becomes its own column, e.g. x1_p.start.val.
# NOTE(review): gather()/spread() are superseded by pivot_longer()/
# pivot_wider(); kept here to match the original question's code.
df1 <- dat %>% 
       tidyr::gather(., index_name, value, x1:x6) %>%
       dplyr::filter(!index_name %in% c('x5','x6')) %>%
       dplyr::group_by(yearID, index_name) %>%
       dplyr::summarise(p.start.val = sum(value[monthID == p.start]),
                        p.val = sum(value[monthID >= p.start & monthID <= p.end]),
                        m.val = sum(value[monthID >= m.start & monthID <= m.end]),
                        h.val = sum(value[monthID >= h.start & monthID <= h.end]),
                        h.end.val = sum(value[monthID == h.end])) %>%
       tidyr::gather(., variable, value, p.start.val:h.end.val) %>%
       dplyr::mutate(new.col.name = paste0(index_name,'_',variable)) %>%
       dplyr::select(-index_name, -variable) %>% 
       tidyr::spread(., new.col.name, value) %>%
       dplyr::mutate(yearRef = 2018)

# Strip the trailing ".val" suffix from the new column names.
# FIX: the original sub(".val", "", ...) treated "." as the regex
# wildcard (any character followed by "val"), which only works here by
# luck; escape the dot and anchor at the end of the name so exactly the
# literal trailing ".val" is removed.
colnames(df1) <- sub("\\.val$", "", colnames(df1))

# df2: for indices x4 and x6, extract the single `value` observed in
# months p.end, m.end and h.end for each year, then reshape to one
# column per index/variable pair (e.g. x4_p.end.val).
# NOTE(review): each summarise expression assumes exactly one row per
# group matches the month test (true for this data, where every yearID
# has one row per monthID) — confirm if the data shape changes.
df2 <- dat %>% 
       tidyr::gather(., index_name, value, x1:x6) %>%
       dplyr::filter(index_name %in% c('x4','x6')) %>%
       dplyr::group_by(yearID, index_name) %>%
       dplyr::summarise(p.end.val = value[monthID == p.end],
                        m.end.val = value[monthID == m.end],
                        h.end.val = value[monthID == h.end]) %>%
       tidyr::gather(., variable, value, p.end.val:h.end.val) %>%
       dplyr::mutate(new.col.name = paste0(index_name,'_',variable)) %>%
       dplyr::select(-index_name, -variable) %>% 
       tidyr::spread(., new.col.name, value) %>%
       dplyr::mutate(yearRef = 2018)

# Strip the trailing ".val" suffix from the new column names.
# FIX: the original sub(".val", "", ...) used an unescaped regex dot
# (matches any character before "val"); escape and anchor it so only a
# literal trailing ".val" is removed.
colnames(df2) <- sub("\\.val$", "", colnames(df2))

# Join the two summary tables on the year keys, keeping every row of
# df1 (left join). With only two tables this is exactly what
# Reduce(merge, list(df1, df2)) computes, spelled out directly.
final.dat <- merge(df1, df2, by = c("yearID", "yearRef"), all.x = TRUE)

end_time <- Sys.time()

end_time - start_time

# Time difference of 2.054761 secs

What I want to do is:

My code above works fine but takes quite a time if the size of dat increases i.e. if number of years become 20000 instead of 10000. I am wondering if someone could help me with a data.table to implement the above solution which I hope would make this faster. Thank you.

标签: r、dplyr、data.table

解决方案


我将只演示 df1 的转换,因为从那里开始,同样的模式很容易在 df2 上重复。

笔记:

  • magrittr 只是用来帮助拆分链中的每个步骤,因为每个 dplyr 动词都可以直接逐一转换。将其改写为不使用 magrittr 的管道并不困难。在我看来,使用它的好处(tidyverse 管道也是如此)是可读性和可维护性。

答案

我将逐步完成以下步骤。

library(data.table)
library(magrittr)

# data.table translation of the dplyr pipeline that builds df1.
# In each step magrittr binds the left-hand side to `.`, so `.[...]`
# applies data.table's `[i, j, by]` operator within the chain.
as.data.table(dat) %>%
  # tidyr::gather -> melt: reshape the x* measure columns to long form
  melt(., measure.vars = grep("^x[0-9]+", colnames(.)),
       variable.name = "index_name", variable.factor = FALSE) %>%
  # dplyr::filter equivalent: drop the x5/x6 indices
  .[ !index_name %in% c("x5", "x6"), ] %>%
  # dplyr::summarise equivalent: window aggregates per (yearID, index_name)
  .[, .(
    p.start.val = sum(value[monthID == p.start]),
    p.val = sum(value[monthID >= p.start & monthID <= p.end]),
    m.val = sum(value[monthID >= m.start & monthID <= m.end]),
    h.val = sum(value[monthID >= h.start & monthID <= h.end]),
    h.end.val = sum(value[monthID == h.end])
  ), by = .(yearID, index_name) ] %>%
  # second gather: melt the five summary columns into variable/value pairs
  melt(., id.vars = 1:2, variable.factor = FALSE) %>%
  # build the wide column names, then drop the helpers (`:=` modifies
  # the data.table in place, no copy)
  .[, new.col.name := paste0(index_name, "_", variable) ] %>%
  .[, c("index_name", "variable") := NULL ] %>%
  # tidyr::spread -> dcast: back to one row per yearID
  dcast(., yearID ~ new.col.name) %>%
  .[, yearRef := 2018 ]

步骤:

步骤说明:

  • 在演练中,我会分别在每个中间 dplyr 管道的末尾添加 dplyr::arrange_all(),在每个中间 data.table 管道的末尾添加 .[order(.),],以便我们进行逐一比较。

  • 您没有为您的样本设置随机种子。我使用了 set.seed(42),因此要将您的控制台输出与我显示的内容进行比较,您需要设置此种子并重新生成 dat。

  • 每个代码块都从上一步的代码继续;为了使这个答案不那么冗长,我将所有重复的代码缩写为 ... %>%。

步骤:

  1. tidyr::gather 对应 data.table::melt。可能有比 grep 更简洁的方法为 melt 选择度量列(例如基于 as.data.table(dat) 直接写出列选择),但 grep 已经可以工作。

    # Step 1, side by side: tidyr::gather (dplyr version) vs
    # data.table::melt. arrange_all()/.[order(.),] are only added so the
    # two intermediate results can be compared row for row.
    dat %>% 
      tidyr::gather(., index_name, value, x1:x6) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 10
    #   yearID monthID p.start p.end m.start m.end h.start h.end index_name  value
    #    <int>   <int>   <dbl> <dbl>   <dbl> <dbl>   <dbl> <dbl> <chr>       <dbl>
    # 1      1       1       6     7       8     9      10    11 x1          1.37 
    # 2      1       1       6     7       8     9      10    11 x2         -0.483
    # 3      1       1       6     7       8     9      10    11 x3         -0.314
    # 4      1       1       6     7       8     9      10    11 x4         -2.23 
    # 5      1       1       6     7       8     9      10    11 x5         -0.717
    # 6      1       1       6     7       8     9      10    11 x6         -1.04 
    # data.table equivalent of the same reshape:
    as.data.table(dat) %>%
      melt(., measure.vars = grep("^x[0-9]+", colnames(.)),
           variable.name = "index_name", variable.factor = FALSE) %>%
      .[order(.),] %>% head() # just for comparison
    #    yearID monthID p.start p.end m.start m.end h.start h.end index_name      value
    # 1:      1       1       6     7       8     9      10    11         x1  1.3709584
    # 2:      1       1       6     7       8     9      10    11         x2 -0.4831687
    # 3:      1       1       6     7       8     9      10    11         x3 -0.3139498
    # 4:      1       1       6     7       8     9      10    11         x4 -2.2323282
    # 5:      1       1       6     7       8     9      10    11         x5 -0.7167575
    # 6:      1       1       6     7       8     9      10    11         x6 -1.0357630
    
    
    
    
  2. 添加 dplyr::filter 和(分组的)dplyr::summarise 的 data.table 等价形式;我实际上只是把 summarise(...) 中新变量的赋值原样复制到 .( ... ) 中,没有必要进行任何更改。

    ... %>%
      dplyr::filter(!index_name %in% c('x5','x6')) %>%
      dplyr::group_by(yearID, index_name) %>%
      dplyr::summarise(p.start.val = sum(value[monthID == p.start]),
                       p.val = sum(value[monthID >= p.start & monthID <= p.end]),
                       m.val = sum(value[monthID >= m.start & monthID <= m.end]),
                       h.val = sum(value[monthID >= h.start & monthID <= h.end]),
                       h.end.val = sum(value[monthID == h.end])) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 7
    # # Groups:   yearID [2]
    #   yearID index_name p.start.val  p.val   m.val  h.val h.end.val
    #    <int> <chr>            <dbl>  <dbl>   <dbl>  <dbl>     <dbl>
    # 1      1 x1             -0.106   1.41   1.92    1.24      1.30 
    # 2      1 x2              0.573  -0.516 -2.29   -3.54     -0.990
    # 3      1 x3              0.767   0.455  0.461   2.28      2.08 
    # 4      1 x4             -0.0559 -1.11  -0.0975 -0.326    -0.483
    # 5      2 x1             -2.66   -5.10   1.01   -1.95     -0.172
    # 6      2 x2              0.342  -0.546  0.605   1.51      1.25 
    ... %>%
      .[ !index_name %in% c("x5", "x6"), ] %>%
      .[, .(
        p.start.val = sum(value[monthID == p.start]),
        p.val = sum(value[monthID >= p.start & monthID <= p.end]),
        m.val = sum(value[monthID >= m.start & monthID <= m.end]),
        h.val = sum(value[monthID >= h.start & monthID <= h.end]),
        h.end.val = sum(value[monthID == h.end])
      ), by = .(yearID, index_name) ] %>%
      .[order(.),] %>% head(.) # just for comparison
    #    yearID index_name p.start.val      p.val       m.val      h.val  h.end.val
    # 1:      1         x1 -0.10612452  1.4053975  1.92376468  1.2421556  1.3048697
    # 2:      1         x2  0.57306337 -0.5164756 -2.28861552 -3.5367198 -0.9901743
    # 3:      1         x3  0.76706512  0.4546020  0.46096277  2.2819246  2.0842981
    # 4:      1         x4 -0.05589648 -1.1093361 -0.09748514 -0.3260778 -0.4825699
    # 5:      2         x1 -2.65645542 -5.0969223  1.01347475 -1.9532258 -0.1719174
    # 6:      2         x2  0.34227065 -0.5457969  0.60537738  1.5136450  1.2498633
    
  3. 再次使用 tidyr::gather,同样对应 data.table::melt。

    ... %>%
      tidyr::gather(., variable, value, p.start.val:h.end.val) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 4
    # # Groups:   yearID [1]
    #   yearID index_name variable     value
    #    <int> <chr>      <chr>        <dbl>
    # 1      1 x1         h.end.val    1.30 
    # 2      1 x1         h.val        1.24 
    # 3      1 x1         m.val        1.92 
    # 4      1 x1         p.start.val -0.106
    # 5      1 x1         p.val        1.41 
    # 6      1 x2         h.end.val   -0.990
    ... %>%
      melt(., id.vars = 1:2, variable.factor = FALSE) %>%
      .[order(.),] %>% head(.) # just for comparison
    #    yearID index_name    variable      value
    # 1:      1         x1   h.end.val  1.3048697
    # 2:      1         x1       h.val  1.2421556
    # 3:      1         x1       m.val  1.9237647
    # 4:      1         x1 p.start.val -0.1061245
    # 5:      1         x1       p.val  1.4053975
    # 6:      1         x2   h.end.val -0.9901743
    
  4. tidyr::spread 对应 data.table::dcast。

    ... %>%
      dplyr::mutate(new.col.name = paste0(index_name,'_',variable)) %>%
      dplyr::select(-index_name, -variable) %>% 
      tidyr::spread(., new.col.name, value) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 21
    # # Groups:   yearID [6]
    #   yearID x1_h.end.val x1_h.val x1_m.val x1_p.start.val x1_p.val x2_h.end.val x2_h.val x2_m.val x2_p.start.val x2_p.val x3_h.end.val x3_h.val x3_m.val x3_p.start.val x3_p.val x4_h.end.val x4_h.val x4_m.val x4_p.start.val x4_p.val
    #    <int>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>
    # 1      1        1.30     1.24     1.92          -0.106    1.41        -0.990   -3.54   -2.29            0.573   -0.516        2.08     2.28     0.461          0.767    0.455      -0.483   -0.326   -0.0975        -0.0559  -1.11  
    # 2      2       -0.172   -1.95     1.01          -2.66    -5.10         1.25     1.51    0.605           0.342   -0.546       -1.38    -0.731    0.443         -0.725   -1.17       -0.623   -1.91     1.49          -0.806   -0.717 
    # 3      3        0.505   -0.104    1.74          -0.640   -0.185        0.570    1.68   -2.24           -0.103   -1.02        -1.36    -2.50    -0.918          1.36     1.26        0.0847  -0.280    0.699          0.114   -0.582 
    # 4      4       -0.811   -0.379   -2.09          -0.361    0.397       -0.782    0.110  -0.0187         -0.641   -0.149       -1.47    -2.45    -1.27           0.418    0.131       0.0582   0.885    0.784          0.998   -0.0115
    # 5      5       -2.99    -2.90     0.956          0.643    0.733        0.165    0.382   1.46            1.48     2.16        -0.451   -0.213   -0.357          0.222    0.686      -0.949   -0.156    1.23           1.35     0.908 
    # 6      6       -1.04    -0.322    1.96           1.30     1.64         0.838   -0.406   1.86            0.863    2.11         0.479    2.37    -1.13          -1.22    -1.63       -0.970    0.0391  -1.08           0.683   -1.24  
    ... %>%
      .[, new.col.name := paste0(index_name, "_", variable) ] %>%
      .[, c("index_name", "variable") := NULL ] %>%
      dcast(., yearID ~ new.col.name) %>%
      .[order(.),] %>% head(.) # just for comparison
    #    yearID x1_h.end.val   x1_h.val   x1_m.val x1_p.start.val   x1_p.val x2_h.end.val   x2_h.val    x2_m.val x2_p.start.val   x2_p.val x3_h.end.val   x3_h.val   x3_m.val x3_p.start.val   x3_p.val x4_h.end.val    x4_h.val    x4_m.val x4_p.start.val    x4_p.val
    # 1:      1    1.3048697  1.2421556  1.9237647     -0.1061245  1.4053975   -0.9901743 -3.5367198 -2.28861552      0.5730634 -0.5164756    2.0842981  2.2819246  0.4609628      0.7670651  0.4546020  -0.48256993 -0.32607779 -0.09748514    -0.05589648 -1.10933614
    # 2:      2   -0.1719174 -1.9532258  1.0134748     -2.6564554 -5.0969223    1.2498633  1.5136450  0.60537738      0.3422707 -0.5457969   -1.3790815 -0.7305400  0.4429124     -0.7249950 -1.1681343  -0.62293711 -1.90725766  1.48980773    -0.80634526 -0.71692479
    # 3:      3    0.5049551 -0.1039713  1.7399409     -0.6399949 -0.1845448    0.5697303  1.6768675 -2.24285021     -0.1029872 -1.0245616   -1.3608773 -2.5029906 -0.9178704      1.3641160  1.2619892   0.08468983 -0.27967757  0.69899862     0.11429665 -0.58216791
    # 4:      4   -0.8113932 -0.3785752 -2.0949859     -0.3610573  0.3971059   -0.7823128  0.1098614 -0.01867344     -0.6414615 -0.1488759   -1.4653210 -2.4476336 -1.2718183      0.4179297  0.1311655   0.05823201  0.88484095  0.78382293     0.99795594 -0.01147192
    # 5:      5   -2.9930901 -2.9032572  0.9558396      0.6428993  0.7326600    0.1645109  0.3819658  1.45532687      1.4820236  2.1608213   -0.4513016 -0.2129462 -0.3572757      0.2221201  0.6855960  -0.94859958 -0.15646638  1.23051588     1.34645936  0.90755241
    # 6:      6   -1.0431189 -0.3222408  1.9592347      1.3025426  1.6383908    0.8379162 -0.4059827  1.86142674      0.8626753  2.1076609    0.4792767  2.3683451 -1.1252801     -1.2213407 -1.6339743  -0.96979464  0.03912882 -1.08199221     0.68254513 -1.23950872
    
  5. 完成它

    df1a <- df1 %>% arrange_all()
    head(df1a)
    # # A tibble: 6 x 22
    # # Groups:   yearID [6]
    #   yearID x1_h.end   x1_h   x1_m x1_p.start   x1_p x2_h.end   x2_h    x2_m x2_p.start   x2_p x3_h.end   x3_h   x3_m x3_p.start   x3_p x4_h.end    x4_h    x4_m x4_p.start    x4_p yearRef
    #    <int>    <dbl>  <dbl>  <dbl>      <dbl>  <dbl>    <dbl>  <dbl>   <dbl>      <dbl>  <dbl>    <dbl>  <dbl>  <dbl>      <dbl>  <dbl>    <dbl>   <dbl>   <dbl>      <dbl>   <dbl>   <dbl>
    # 1      1    1.30   1.24   1.92      -0.106  1.41    -0.990 -3.54  -2.29        0.573 -0.516    2.08   2.28   0.461      0.767  0.455  -0.483  -0.326  -0.0975    -0.0559 -1.11      2018
    # 2      2   -0.172 -1.95   1.01      -2.66  -5.10     1.25   1.51   0.605       0.342 -0.546   -1.38  -0.731  0.443     -0.725 -1.17   -0.623  -1.91    1.49      -0.806  -0.717     2018
    # 3      3    0.505 -0.104  1.74      -0.640 -0.185    0.570  1.68  -2.24       -0.103 -1.02    -1.36  -2.50  -0.918      1.36   1.26    0.0847 -0.280   0.699      0.114  -0.582     2018
    # 4      4   -0.811 -0.379 -2.09      -0.361  0.397   -0.782  0.110 -0.0187     -0.641 -0.149   -1.47  -2.45  -1.27       0.418  0.131   0.0582  0.885   0.784      0.998  -0.0115    2018
    # 5      5   -2.99  -2.90   0.956      0.643  0.733    0.165  0.382  1.46        1.48   2.16    -0.451 -0.213 -0.357      0.222  0.686  -0.949  -0.156   1.23       1.35    0.908     2018
    # 6      6   -1.04  -0.322  1.96       1.30   1.64     0.838 -0.406  1.86        0.863  2.11     0.479  2.37  -1.13      -1.22  -1.63   -0.970   0.0391 -1.08       0.683  -1.24      2018
    df1b <- ... %>%
      .[, yearRef := 2018 ] %>%
      .[order(.),]
    head(df1b)
    #    yearID x1_h.end.val   x1_h.val   x1_m.val x1_p.start.val   x1_p.val x2_h.end.val   x2_h.val    x2_m.val x2_p.start.val   x2_p.val x3_h.end.val   x3_h.val   x3_m.val x3_p.start.val   x3_p.val x4_h.end.val    x4_h.val    x4_m.val x4_p.start.val    x4_p.val yearRef
    # 1:      1    1.3048697  1.2421556  1.9237647     -0.1061245  1.4053975   -0.9901743 -3.5367198 -2.28861552      0.5730634 -0.5164756    2.0842981  2.2819246  0.4609628      0.7670651  0.4546020  -0.48256993 -0.32607779 -0.09748514    -0.05589648 -1.10933614    2018
    # 2:      2   -0.1719174 -1.9532258  1.0134748     -2.6564554 -5.0969223    1.2498633  1.5136450  0.60537738      0.3422707 -0.5457969   -1.3790815 -0.7305400  0.4429124     -0.7249950 -1.1681343  -0.62293711 -1.90725766  1.48980773    -0.80634526 -0.71692479    2018
    # 3:      3    0.5049551 -0.1039713  1.7399409     -0.6399949 -0.1845448    0.5697303  1.6768675 -2.24285021     -0.1029872 -1.0245616   -1.3608773 -2.5029906 -0.9178704      1.3641160  1.2619892   0.08468983 -0.27967757  0.69899862     0.11429665 -0.58216791    2018
    # 4:      4   -0.8113932 -0.3785752 -2.0949859     -0.3610573  0.3971059   -0.7823128  0.1098614 -0.01867344     -0.6414615 -0.1488759   -1.4653210 -2.4476336 -1.2718183      0.4179297  0.1311655   0.05823201  0.88484095  0.78382293     0.99795594 -0.01147192    2018
    # 5:      5   -2.9930901 -2.9032572  0.9558396      0.6428993  0.7326600    0.1645109  0.3819658  1.45532687      1.4820236  2.1608213   -0.4513016 -0.2129462 -0.3572757      0.2221201  0.6855960  -0.94859958 -0.15646638  1.23051588     1.34645936  0.90755241    2018
    # 6:      6   -1.0431189 -0.3222408  1.9592347      1.3025426  1.6383908    0.8379162 -0.4059827  1.86142674      0.8626753  2.1076609    0.4792767  2.3683451 -1.1252801     -1.2213407 -1.6339743  -0.96979464  0.03912882 -1.08199221     0.68254513 -1.23950872    2018
    

他们确实匹配:

# Confirm the dplyr result and the data.table result agree exactly;
# both are coerced to plain data.frames first so that class and
# grouping attributes do not cause a spurious mismatch.
identical(as.data.frame(df1a), as.data.frame(df1b))
# [1] TRUE

加速并不巨大,但似乎确实可观。加快您现有 dplyr 代码的一种方法是:一旦不再需要分组就立即取消分组——如果我在 summarise(...) 之后立即添加 ungroup(),会看到一个小的改进。

microbenchmark::microbenchmark(
  dplyr = { ... },
  dplyr_ungrp = { ... },
  data.table = { ... },
  times = 10
)
# Unit: milliseconds
#         expr      min        lq      mean    median        uq       max neval
#        dplyr 988.8311 1021.4725 1048.5462 1045.6885 1066.2733 1135.6032    10
#  dplyr_ungrp 909.3643  913.9301  952.6282  937.6540  998.2802 1041.2144    10
#   data.table 457.4500  465.1788  478.1471  474.2388  478.9840  531.1449    10

推荐阅读