首页 > 解决方案 > 如何使用 dplyr 获取每行具有最大值的列

问题描述

我在 R 中有一个数据框。对于每一行,我想选择哪一列具有最高值,然后粘贴该列的名称。当只有两列可供选择时,这很简单(请注意,如果两列的值都小于 0.1,我有一个不包括行的过滤步骤):

set.seed(6)
mat_simple <- matrix(rexp(200, rate=.1), ncol=2) %>%
    as.data.frame() 

head(mat_simple)
         V1         V2
1  2.125366  6.7798683
2  1.832349  8.9610534
3  6.149668 15.7777370
4  3.532614  0.2355711
5 21.110703  1.2927119
6  2.871455 16.7370847
    
mat_simple <- mat_simple %>%
    mutate(
        class = case_when(
            V1 < 0.1 & V2 < 0.1 ~ NA_character_,
            V1 > V2 ~ "V1",
            V2 > V1 ~ "V2"
        )
    )

head(mat_simple)
         V1         V2 class
1  2.125366  6.7798683    V2
2  1.832349  8.9610534    V2
3  6.149668 15.7777370    V2
4  3.532614  0.2355711    V1
5 21.110703  1.2927119    V1
6  2.871455 16.7370847    V2

但是,当有两列以上时,这不起作用。例如:

set.seed(6)
mat_hard <- matrix(rexp(200, rate=.1), ncol=5) %>%
     as.data.frame() 

head(mat_hard)
         V1        V2         V3         V4        V5
1  2.125366 26.427335 13.7289349  1.7513873  6.297978
2  1.832349 10.241441  5.3084648  0.3347235 29.247774
3  6.149668  5.689442  5.4546072  4.5035747 11.646721
4  3.532614 10.397464  6.5560545  4.4221171  1.713909
5 21.110703  9.928022  0.2284966  0.2101213  1.033498
6  2.871455  4.781357  3.3246585 15.8878010  4.004967

有没有更好的解决方案,最好使用dplyr

标签: rdataframedplyrrowwise

解决方案


您可以使用以下解决方案:

library(dplyr)

mat_hard %>%
  rowwise() %>%
  mutate(max = names(mat_hard)[c_across(everything()) == max(c_across(everything()))])


# A tibble: 40 x 6
# Rowwise: 
       V1    V2      V3      V4      V5 max  
    <dbl> <dbl>   <dbl>   <dbl>   <dbl> <chr>
 1  2.48   1.73  3.97   24.2    12.7    V4   
 2  9.18   8.86 13.8     9.26    7.64   V3   
 3  6.22   5.96  0.0911  6.66    0.274  V4   
 4  1.14  24.0  12.0     7.37    5.39   V2   
 5  8.09   8.30 24.1     4.01    0.0674 V3   
 6  7.97   2.76  2.21   16.0     0.805  V4   
 7  0.135  2.05  1.85    0.645   2.15   V5   
 8 23.9   31.0   6.48    0.0328 15.1    V2   
 9  9.46   5.66 40.8    12.6     0.320  V3   
10  7.23   8.10  2.06    2.61    6.14   V2   
# ... with 30 more rows

或者您可以通过which.max以下方式使用函数来减少冗长:

mat_hard %>%
  rowwise() %>%
  mutate(max = names(cur_data())[which.max(c_across(everything()))])

# A tibble: 40 x 6
# Rowwise: 
       V1    V2      V3      V4      V5 max  
    <dbl> <dbl>   <dbl>   <dbl>   <dbl> <chr>
 1  2.48   1.73  3.97   24.2    12.7    V4   
 2  9.18   8.86 13.8     9.26    7.64   V3   
 3  6.22   5.96  0.0911  6.66    0.274  V4   
 4  1.14  24.0  12.0     7.37    5.39   V2   
 5  8.09   8.30 24.1     4.01    0.0674 V3   
 6  7.97   2.76  2.21   16.0     0.805  V4   
 7  0.135  2.05  1.85    0.645   2.15   V5   
 8 23.9   31.0   6.48    0.0328 15.1    V2   
 9  9.46   5.66 40.8    12.6     0.320  V3   
10  7.23   8.10  2.06    2.61    6.14   V2   
# ... with 30 more rows

推荐阅读