首页 > 解决方案 > 如何在 tidymodels 框架下提取分类器对单个数据点的预测?

问题描述

我正在做一个文本分类项目,我一直在 tidymodels 框架下做所有事情。现在,我正在尝试调查特定数据点是否一直被错误标记。为此,我想了解单个样本的已保存预测。当我执行重采样并使用 collect_predictions 时,虽然我看到一个包含每个数据点的预测标签和实际标签的列表,但数据点本身的身份仍然是隐藏的。有一列可以追溯到(.row),但我无法确认这一点。

我一直在生成我的重采样策略,如下所示:

grades_split <- initial_split(tabled_texts2, strata = grade)

grades_train <- training(grades_split)
grades_test <- testing(grades_split)

folds <- vfold_cv(grades_train)

然后,在调整和拟合模型之后,我生成 resamples 对象:

fitted_grades <- fit(final_wf, grades_train)

LR_rs <- fit_resamples(
  fitted_grades,
  folds,
  control = control_resamples(save_pred = TRUE)
)

最后,我检查这样的预测:

predictions <- collect_predictions(LR_rs)
View(predictions)

我得到一个看起来像这样的表:

ID .pred_4 .pred_not 4 。排 .pred_class 年级 .config
折叠01 0.502905 0.497095 18 4 4 预处理器1_Model1
折叠01 0.484647 0.515353 22 不是 4 4 预处理器1_Model1
折叠01 0.481496 0.518504 23 不是 4 4 预处理器1_Model1
折叠01 0.492314 0.507686 40 不是 4 4 预处理器1_Model1
折叠01 0.477215 0.522785 52 不是 4 4 预处理器1_Model1

如何将这些值映射回原始数据?

这是一个类似的代表。在这个例子中,我希望能够具体看到哪些企鹅被错误分类,而不仅仅是一个任意的 .row 值(我很确定它不会 1-1 映射回原始数据集)

library(tidyverse)
library(tidymodels)
library(tidytext)
library(modeldata)
library(naivebayes)
library(discrim)
set.seed(1)

data("penguins")
View(penguins)
nb_spec <- naive_Bayes() %>%
  set_mode('classification') %>%
  set_engine('naivebayes')

fitted_wf <- workflow() %>%
  add_formula(species ~ island + flipper_length_mm) %>%
  add_model(nb_spec) %>%
  fit(penguins)


split <- initial_split(penguins)

train <- training(split)
test <- testing(split)

folds <- vfold_cv(train)

NB_rs <- fit_resamples(
  fitted_wf,
  folds,
  control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
View(predictions)

标签: rmachine-learningtidymodels

解决方案


实际上,该.row列确实告诉您这些预测中的每一个来自训练数据集中的哪一行。让我们看看我们是否可以说服您:

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness
set.seed(1)

data("penguins")
nb_spec <- naive_Bayes() %>%
   set_mode('classification') %>%
   set_engine('naivebayes')

fitted_wf <- workflow() %>%
   add_formula(species ~ island + flipper_length_mm) %>%
   add_model(nb_spec) 

split <- penguins %>%
   na.omit() %>%
   initial_split()

penguin_train <- training(split)
penguin_test <- testing(split)

folds <- vfold_cv(penguin_train)

NB_rs <- fit_resamples(
   fitted_wf,
   folds,
   control = control_resamples(save_pred = TRUE)
)

predictions <- collect_predictions(NB_rs)

让我们只看其中一个折叠:

predictions %>% filter(id == "Fold01")
#> # A tibble: 25 × 8
#>    id     .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row .pred_class species  
#>    <chr>         <dbl>           <dbl>        <dbl> <int> <fct>       <fct>    
#>  1 Fold01     0.609        0.391        0.000000526     3 Adelie      Adelie   
#>  2 Fold01     0.182        0.818        0.000104        8 Chinstrap   Adelie   
#>  3 Fold01     0.423        0.577        0.000000325     9 Chinstrap   Chinstrap
#>  4 Fold01     0.999        0.00120      0.00000137     21 Adelie      Adelie   
#>  5 Fold01     0.000178     0.0000310    1.00           27 Gentoo      Gentoo   
#>  6 Fold01     0.552        0.448        0.000000395    36 Adelie      Adelie   
#>  7 Fold01     0.997        0.000392     0.00275        45 Adelie      Adelie   
#>  8 Fold01     0.000211     0.000000780  1.00           48 Gentoo      Gentoo   
#>  9 Fold01     0.998        0.00129      0.00114        60 Adelie      Adelie   
#> 10 Fold01     0.00313      0.000100     0.997          79 Gentoo      Gentoo   
#> # … with 15 more rows, and 1 more variable: .config <chr>

这有第 3、8、9 行等。它是第一个重采样的评估folds

现在让我们看一下训练数据:

penguin_train
#> # A tibble: 249 × 7
#>    species   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>     <fct>           <dbl>         <dbl>             <int>       <int>
#>  1 Chinstrap Dream            50.2          18.8               202        3800
#>  2 Gentoo    Biscoe           50.2          14.3               218        5700
#>  3 Adelie    Dream            38.1          17.6               187        3425
#>  4 Chinstrap Dream            51            18.8               203        4100
#>  5 Chinstrap Dream            52.7          19.8               197        3725
#>  6 Gentoo    Biscoe           49.6          16                 225        5700
#>  7 Chinstrap Dream            46.2          17.5               187        3650
#>  8 Adelie    Dream            35.7          18                 202        3550
#>  9 Chinstrap Dream            51.7          20.3               194        3775
#> 10 Gentoo    Biscoe           50.4          15.7               222        5750
#> # … with 239 more rows, and 1 more variable: sex <fct>

reprex 包于 2021-07-30 创建 (v2.0.0 )

查看第 3、8 和 9 行;匹配,species因为这些是相同的行!

请注意,对于每个折叠,您可能会得到不同的预测folds,因为它们有不同的训练集,我们称之为分析集。


推荐阅读