首页 > 解决方案 > 加快多分位数的 Tidyverse 计算

问题描述

我有这个很棒的小功能summarise_posterior(如下所示)作为我的包的一部分driver在此处可用)。

该功能很棒而且超级有用。一个问题是我一直在处理越来越大的数据,而且速度可能非常慢。简而言之,我的问题是:是否有一种 tidyverse-esque 方式可以加快速度,同时仍保留此功能的关键灵活性(请参阅文档中的示例)。

至少一个关键的加速可能来自于弄清楚如何将分位数的计算放在一次调用中,而不是一遍又一遍地调用分位数函数。当前实现的后一个选项可能是一遍又一遍地重新排序相同的向量。

#' Shortcut for summarize variable with quantiles and mean
#'
#' @param data tidy data frame
#' @param var variable name (unquoted) to be summarised
#' @param ... other expressions to pass to summarise
#'
#' @return data.frame
#' @export
#' @details Notation: \code{pX} refers to the \code{X}\% quantile
#' @import dplyr
#' @importFrom stats quantile
#' @importFrom rlang quos quo UQ
#' @examples
#' d <- data.frame("a"=sample(1:10, 50, TRUE),
#'                 "b"=rnorm(50))
#'
#' # Summarize posterior for b over grouping of a and also calcuate
#' # minmum of b (in addition to normal statistics returned)
#' d <- dplyr::group_by(d, a)
#' summarise_posterior(d, b, mean.b = mean(b), min=min(b))
summarise_posterior <- function(data, var, ...){
  qvar <- enquo(var)
  qs <- quos(...)


  data %>%
    summarise(p2.5 = quantile(!!qvar, prob=0.025),
              p25 = quantile(!!qvar, prob=0.25),
              p50 = quantile(!!qvar, prob=0.5),
              mean = mean(!!qvar),
              p75 = quantile(!!qvar, prob=0.75),
              p97.5 = quantile(!!qvar, prob=0.975),
              !!!qs)
}

Rcpp 后端选项也非常受欢迎。

谢谢!

标签: rdplyrtidyverserlang

解决方案


这是一个利用嵌套来避免quantile多次调用的解决方案。任何时候你需要在里面存储一个结果向量summarize,只需将它包装在里面list。之后,您可以取消嵌套这些结果,将它们与它们的名称配对,并用于spread将它们放在单独的列中:

summarise_posterior2 <- function(data, var, ...){
  qvar <- ensym(var)
  vq <- c(0.025, 0.25, 0.5, 0.75, 0.975)

  summarise( data, .qq = list(quantile(!!qvar, vq, names=FALSE)),
             .nms = list(str_c("p", vq*100)), mean = mean(!!qvar), ... ) %>%
  unnest %>% spread( .nms, .qq )  
}

这不会给你带来与@jay.sf 解决方案几乎相同的速度

d <- data.frame("a"=sample(1:10, 5e5, TRUE), "b"=rnorm(5e5))    
microbenchmark::microbenchmark( f1 = summarise_posterior(d, b, mean.b = mean(b), min=min(b)),
                                f2 = summarise_posterior2(d, b, mean.b = mean(b), min=min(b)) )
# Unit: milliseconds
#  expr      min       lq     mean   median       uq      max neval
#    f1 49.06697 50.81422 60.75100 52.43030 54.17242 200.2961   100
#    f2 29.05209 29.66022 32.32508 30.84492 32.56364 138.9579   100

但它可以group_by在嵌套函数和嵌套函数中正常工作(substitute基于 - 的解决方案在嵌套时会中断)

r1 <- d %>% dplyr::group_by(a) %>% summarise_posterior(b, mean.b = mean(b), min=min(b))
r2 <- d %>% dplyr::group_by(a) %>% summarise_posterior2(b, mean.b = mean(b), min=min(b))
all_equal( r1, r2 )     # TRUE

如果你分析代码,你可以看到主要的挂断在哪里

Rprof()
for( i in 1:100 )
  d %>% dplyr::group_by(a) %>% summarise_posterior2(b, mean.b = mean(b), min=min(b))
Rprof(NULL)
summaryRprof()$by.self %>% head
#             self.time self.pct total.time total.pct
# ".Call"          1.84    49.73       3.18     85.95
# "sort.int"       0.94    25.41       1.12     30.27
# "eval"           0.08     2.16       3.64     98.38
# "tryCatch"       0.08     2.16       1.44     38.92
# "anyNA"          0.08     2.16       0.08      2.16
# "structure"      0.04     1.08       0.08      2.16

主要.Call对应的是C++后端dplyr,而sort.int后面是worker quantile()。@jay.sf 的解决方案通过与 解耦获得了显着的加速dplyr,但它也失去了相关的灵活性(例如,与 集成group_by)。最终,由您决定哪个更重要。


推荐阅读