首页 > 解决方案 > 有没有更快的方法在 R 中制作这个混淆矩阵表?

问题描述

我正在尝试使用以下数据框在 R 中制作混淆矩阵表:

mydf <- structure(list(pred_class = c("dog", "dog", "fish", "cat", "cat", 
"dog", "fish", "cat", "dog", "fish"), true_class = c("cat", "cat", 
"dog", "cat", "cat", "dog", "dog", "cat", "dog", "fish")), row.names = c(NA, 
10L), class = "data.frame")

  pred_class true_class
1        dog        cat
2        dog        cat
3       fish        dog
4        cat        cat
5        cat        cat
6        dog        dog

我已经编写了代码来做我想做的事——对于每一类(狗、猫或鱼),说每一行是真阳性、假阳性、真阴性还是假阴性。

conf_mat <- mydf %>%
    mutate(
        dog_conf = case_when(
            true_class == "dog" &  pred_class == "dog" ~ "TP",
            true_class == "dog" &  pred_class != "dog" ~ "FN",
            true_class != "dog" &  pred_class == "dog" ~ "FP",
            true_class != "dog" &  pred_class != "dog" ~ "TN"
        ),
        cat_conf = case_when(
            true_class == "cat" &  pred_class == "cat" ~ "TP",
            true_class == "cat" &  pred_class != "cat" ~ "FN",
            true_class != "cat" &  pred_class == "cat" ~ "FP",
            true_class != "cat" &  pred_class != "cat" ~ "TN"
        ),
        fish_conf = case_when(
            true_class == "fish" &  pred_class == "fish" ~ "TP",
            true_class == "fish" &  pred_class != "fish" ~ "FN",
            true_class != "fish" &  pred_class == "fish" ~ "FP",
            true_class != "fish" &  pred_class != "fish" ~ "TN"
        )
    )

但是,此代码非常重复且庞大。我不确定如何简化这一点。有没有人有什么建议?谢谢你。

标签: r

解决方案


这是一个选项map,我们循环遍历数据集的唯一元素,transmute根据 OP 帖子中指定的条件在循环中创建列,并将这些列与原始数据绑定

library(dplyr)
library(purrr)
library(stringr)

map_dfc(unique(unlist(mydf)), ~ 
      mydf %>% 
           transmute(!! str_c(.x, '_conf') := 
        case_when(true_class == .x &  pred_class == .x ~ "TP",
            true_class == .x &  pred_class != .x ~ "FN",
            true_class != .x &  pred_class == .x ~ "FP",
            true_class != .x &  pred_class != .x ~ "TN"
        ))) %>% 
   bind_cols(mydf, .)

-输出

#     pred_class true_class dog_conf cat_conf fish_conf
#1         dog        cat       FP       FN        TN
#2         dog        cat       FP       FN        TN
#3        fish        dog       FN       TN        FP
#4         cat        cat       TN       TP        TN
#5         cat        cat       TN       TP        TN
#6         dog        dog       TP       TN        TN
#7        fish        dog       FN       TN        FP
#8         cat        cat       TN       TP        TN
#9         dog        dog       TP       TN        TN
#10       fish       fish       TN       TN        TP

或者merge在 key val 数据集上使用

keydat <- data.frame(pred_class = c(TRUE, TRUE, FALSE, FALSE), 
   true_class = c(TRUE, FALSE, TRUE, FALSE), 
  conf = c("TP", "FN", "FP", "TN"))

un1 <- unique(unlist(mydf))
mydf[paste0(un1, "_conf")] <- lapply(un1, function(x)
             merge(mydf == x, keydat, all.x = TRUE)$conf)

推荐阅读