r - 如何在 tidymodels 框架下提取分类器对单个数据点的预测?
问题描述
我正在做一个文本分类项目,我一直在 tidymodels 框架下做所有事情。现在,我正在尝试调查特定数据点是否一直被错误标记。为此,我想了解单个样本的已保存预测。当我执行重采样并使用 collect_predictions 时,虽然我看到一个包含每个数据点的预测标签和实际标签的列表,但数据点本身的身份仍然是隐藏的。有一列可以追溯到(.row),但我无法确认这一点。
我一直在生成我的重采样策略,如下所示:
grades_split <- initial_split(tabled_texts2, strata = grade)
grades_train <- training(grades_split)
grades_test <- testing(grades_split)
folds <- vfold_cv(grades_train)
然后,在调整和拟合模型之后,我生成 resamples 对象:
fitted_grades <- fit(final_wf, grades_train)
LR_rs <- fit_resamples(
fitted_grades,
folds,
control = control_resamples(save_pred = TRUE)
)
最后,我检查这样的预测:
predictions <- collect_predictions(LR_rs)
View(predictions)
我得到一个看起来像这样的表:
ID | .pred_4 | .pred_not 4 | 。排 | .pred_class | 年级 | .config |
---|---|---|---|---|---|---|
折叠01 | 0.502905 | 0.497095 | 18 | 4 | 4 | 预处理器1_Model1 |
折叠01 | 0.484647 | 0.515353 | 22 | 不是 4 | 4 | 预处理器1_Model1 |
折叠01 | 0.481496 | 0.518504 | 23 | 不是 4 | 4 | 预处理器1_Model1 |
折叠01 | 0.492314 | 0.507686 | 40 | 不是 4 | 4 | 预处理器1_Model1 |
折叠01 | 0.477215 | 0.522785 | 52 | 不是 4 | 4 | 预处理器1_Model1 |
如何将这些值映射回原始数据?
这是一个类似的代表。在这个例子中,我希望能够具体看到哪些企鹅被错误分类,而不仅仅是一个任意的 .row 值(我很确定它不会 1-1 映射回原始数据集)
library(tidyverse)
library(tidymodels)
library(tidytext)
library(modeldata)
library(naivebayes)
library(discrim)
set.seed(1)
data("penguins")
View(penguins)
nb_spec <- naive_Bayes() %>%
set_mode('classification') %>%
set_engine('naivebayes')
fitted_wf <- workflow() %>%
add_formula(species ~ island + flipper_length_mm) %>%
add_model(nb_spec) %>%
fit(penguins)
split <- initial_split(penguins)
train <- training(split)
test <- testing(split)
folds <- vfold_cv(train)
NB_rs <- fit_resamples(
fitted_wf,
folds,
control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
View(predictions)
解决方案
实际上,该.row
列确实告诉您这些预测中的每一个来自训练数据集中的哪一行。让我们看看我们是否可以说服您:
library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(discrim)
#>
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#>
#> smoothness
set.seed(1)
data("penguins")
nb_spec <- naive_Bayes() %>%
set_mode('classification') %>%
set_engine('naivebayes')
fitted_wf <- workflow() %>%
add_formula(species ~ island + flipper_length_mm) %>%
add_model(nb_spec)
split <- penguins %>%
na.omit() %>%
initial_split()
penguin_train <- training(split)
penguin_test <- testing(split)
folds <- vfold_cv(penguin_train)
NB_rs <- fit_resamples(
fitted_wf,
folds,
control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
让我们只看其中一个折叠:
predictions %>% filter(id == "Fold01")
#> # A tibble: 25 × 8
#> id .pred_Adelie .pred_Chinstrap .pred_Gentoo .row .pred_class species
#> <chr> <dbl> <dbl> <dbl> <int> <fct> <fct>
#> 1 Fold01 0.609 0.391 0.000000526 3 Adelie Adelie
#> 2 Fold01 0.182 0.818 0.000104 8 Chinstrap Adelie
#> 3 Fold01 0.423 0.577 0.000000325 9 Chinstrap Chinstrap
#> 4 Fold01 0.999 0.00120 0.00000137 21 Adelie Adelie
#> 5 Fold01 0.000178 0.0000310 1.00 27 Gentoo Gentoo
#> 6 Fold01 0.552 0.448 0.000000395 36 Adelie Adelie
#> 7 Fold01 0.997 0.000392 0.00275 45 Adelie Adelie
#> 8 Fold01 0.000211 0.000000780 1.00 48 Gentoo Gentoo
#> 9 Fold01 0.998 0.00129 0.00114 60 Adelie Adelie
#> 10 Fold01 0.00313 0.000100 0.997 79 Gentoo Gentoo
#> # … with 15 more rows, and 1 more variable: .config <chr>
这有第 3、8、9 行等。它是第一个重采样的评估集folds
。
现在让我们看一下训练数据:
penguin_train
#> # A tibble: 249 × 7
#> species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#> <fct> <fct> <dbl> <dbl> <int> <int>
#> 1 Chinstrap Dream 50.2 18.8 202 3800
#> 2 Gentoo Biscoe 50.2 14.3 218 5700
#> 3 Adelie Dream 38.1 17.6 187 3425
#> 4 Chinstrap Dream 51 18.8 203 4100
#> 5 Chinstrap Dream 52.7 19.8 197 3725
#> 6 Gentoo Biscoe 49.6 16 225 5700
#> 7 Chinstrap Dream 46.2 17.5 187 3650
#> 8 Adelie Dream 35.7 18 202 3550
#> 9 Chinstrap Dream 51.7 20.3 194 3775
#> 10 Gentoo Biscoe 50.4 15.7 222 5750
#> # … with 239 more rows, and 1 more variable: sex <fct>
由reprex 包于 2021-07-30 创建 (v2.0.0 )
查看第 3、8 和 9 行;匹配,species
因为这些是相同的行!
请注意,对于每个折叠,您可能会得到不同的预测folds
,因为它们有不同的训练集,我们称之为分析集。
推荐阅读
- flutter - 监听颤振的可变变化
- c# - 在第二次迭代中,内部 foreach 循环不会从第二个元素继续
- postgresql - 将列数据拆分为两列并将其插入到 postgresql 中的现有表中
- php - 当在 laravel 中使用 ajax 从下拉列表中选择产品时,如何创建显示产品价格的标签
- python - 从另一个文本文件提供 SFTP 服务器凭据
- python - 有没有一种方法可以从 python 中的多个文本文件中提取多条数据并将其保存为新的 .csv 文件中的一行?
- angular - ionic 4: dateStart 和 dateEnd 有 2 小时的差异
- docker - 如何为我的项目的多个版本共享一个 Dockerfile?
- android - 如何根据导航组件中的片段内的某些条件隐藏向上导航图标?
- single-sign-on - ADFS 登录给出错误 Microsoft.IdentityServer.Web.UnsupportedSamlRequestException