首页 > 解决方案 > 按组有效过滤多个列

问题描述

假设一个数据集包含每个 ID 多行和多列,其中包含一些存储为字符串的代码:

df <- data.frame(id = rep(1:3, each = 2),
                 var1 = c("X1", "Y1", "Y2", "Y3", "Z1", "Z2"),
                 var2 = c("Y1", "X2", "Y2", "Y3", "Z1", "Z2"),
                 var3 = c("Y1", "Y2", "X1", "Y3", "Z1", "Z2"),
                 stringsAsFactors = FALSE)

  id var1 var2 var3
1  1   X1   Y1   Y1
2  1   Y1   X2   Y2
3  2   Y2   Y2   X1
4  2   Y3   Y3   Y3
5  3   Z1   Z1   Z1
6  3   Z2   Z2   Z2

现在,假设我想过滤掉X任何相关列中具有特定代码(此处)的所有 ID。使用dplyrand purrr,我可以这样做:

df %>%
 group_by(id) %>%
 filter(all(reduce(.x = across(var1:var3, ~ !grepl("^X", .)), .f = `&`)))

     id var1  var2  var3 
  <int> <chr> <chr> <chr>
1     3 Z1    Z1    Z1   
2     3 Z2    Z2    Z2 

它工作正常,紧凑且易于理解,但是,对于大型数据集(数百万个 ID 和数千万个观察值),它的效率相当低。我欢迎使用任何库来计算更高效的代码的任何想法。

标签: rregexdataframeperformancefiltering

解决方案


一些可能的速度点

  • 尽量不要使用 group by 之类的东西,即group_byindplyrby = in data.table,因为这会降低你的整体表现
  • 如果您有固定的目标模式,例如,从 开始X,那么substr可能比grepl使用模式更有效^X

一些基本 R 方法

看来我们可以根据@Waldi的最快方法 通过以下方法进一步加快速度

TIC1 <- function() {
    subset(df, ave(rowSums(substr(as.matrix(df[, -1]), 1, 1) == "X") == 0, id, FUN = all))
}

或者

TIC2 <- function() {
    subset(df, !id %in% id[rowSums(substr(as.matrix(df[, -1]), 1, 1) == "X") > 0])
}

或者

TIC3 <- function() {
    subset(df, !id %in% id[do.call(pmax, lapply(df[-1], function(v) substr(v, 1, 1) == "X")) > 0])
}

基准测试

@Waldi@EnricoSchumann的答案相比:

microbenchmark(
    TIC1(),
    TIC2(),
    TIC3(),
    fun1(),
    fun2(),
    waldi_speed(),
    unit = "relative"
)

Unit: relative
          expr       min        lq      mean    median        uq       max
        TIC1()  3.385215  3.451424  3.488670  3.569668  3.684895  3.618991
        TIC2()  1.062116  1.084568  1.074789  1.090400  1.114443  1.027673
        TIC3()  1.077660  2.208734  2.185960  2.214180  2.293366  2.141994
        fun1()  1.166342  1.155096  1.169574  1.153223  1.207932  1.405530
        fun2()  1.000000  1.000000  1.000000  1.000000  1.000000  1.000000
 waldi_speed() 26.218953 26.560429 26.373054 26.952997 27.396017 26.333575
 neval
   100
   100
   100
   100
   100
   100

给定

n <- 5e4
df <- data.frame(
    id = rep(1:(n / 2), each = 2, length.out = n),
    var1 = mapply(paste0, LETTERS[23 + sample(1:3, n, replace = T)], sample(1:3, n, replace = T)),
    var2 = mapply(paste0, LETTERS[23 + sample(1:3, n, replace = T)], sample(1:3, n, replace = T)),
    var3 = mapply(paste0, LETTERS[23 + sample(1:3, n, replace = T)], sample(1:3, n, replace = T)),
    stringsAsFactors = FALSE
)

TIC1 <- function() {
    subset(df, ave(rowSums(substr(as.matrix(df[, -1]), 1, 1) == "X") == 0, id, FUN = all))
}

TIC2 <- function() {
    subset(df, !id %in% id[rowSums(substr(as.matrix(df[, -1]), 1, 1) == "X") > 0])
}

TIC3 <- function() {
    subset(df, !id %in% id[do.call(pmax, lapply(df[-1], function(v) substr(v, 1, 1) == "X")) > 0])
}


waldi_speed <- function() {
    setDT(df)
    df[df[, .(keep = .I[!any(grepl("X", .SD))]), by = id, .SDcols = patterns("var")]$keep]
}


repeated_or <- function(...) {
    L <- list(...)
    ans <- L[[1L]]
    if (...length() > 1L) {
          for (i in seq.int(2L, ...length())) {
                ans <- ans | L[[i]]
            }
      }
    ans
}

fun1 <- function() {
    ## using a pattern
    m <- lapply(df[, -1], grepl, pattern = "^X", perl = TRUE)
    df[!df$id %in% df$id[do.call(repeated_or, m)], ]
}

fun2 <- function() {
    ## using a fixed string
    m <- lapply(df[, -1], function(x) substr(x, 1, 1) == "X")
    df[!df$id %in% df$id[do.call(repeated_or, m)], ]
}

推荐阅读