How to use different engine parameters for each fold with tidymodels during cross-validation?

Problem description

I would like to use tidymodels to tune a ranger random forest in a cross-validated setup. My dataset is unbalanced, so I would like to use the ranger parameter class.weights.

However, the weights can differ from fold to fold. How can I pass the fold-specific weights to the engine?

MWE:

library(tidyverse)
library(tidymodels)
set.seed(111)

iris_cut <- iris[30:110,] # Dummy unbalanced dataset

# Create folds  
folds <- vfold_cv(iris_cut, v = 3)

# Calculate class weights (one per class, in factor-level order, as ranger expects):
calc_weights <- function(df) {
  df %>% 
    count(Species) %>%               # one row per class, sorted by factor level
    mutate(weight = max(n) / n) %>%  # the largest class gets weight 1
    pull(weight)                     # classes with equal counts are kept, unlike distinct(n_total)
}

# Use during training of fold 1:
weights_fold1 <- analysis(folds$splits[[1]]) %>% calc_weights()
# Use during training of fold 2:
weights_fold2 <- analysis(folds$splits[[2]]) %>% calc_weights()
# Use during training of fold 3:
weights_fold3 <- analysis(folds$splits[[3]]) %>% calc_weights()


# Defining a recipe
rec <- recipe(Species~ ., data = iris_cut) 
  
# Create Model Specification
rf_mod <- rand_forest(
  mtry = tune(), 
  trees = 1000, 
  min_n = 1
) %>% 
  set_mode("classification") %>% 
  set_engine("ranger",
             class.weights=!!weights_fold1 # Wrong! Here for each fold another weight vector should be passed
             )  

rf_grid <- crossing(
  mtry = c(1,2,3)
)
  
# Setup workflow 
tune_wf <- workflow() %>% 
  add_recipe(rec) %>% 
  add_model(rf_mod)

# Start rf tuning
tune_res <- tune_grid( 
  tune_wf,
  resamples = folds,
  grid = rf_grid
)

Tags: r, tidymodels

Solution


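The tidymodels ecosystem doesn't support passing that kind of fold-specific weight to the engine. Instead, we encourage folks to use the themis package to handle class imbalance during feature engineering. Because a recipe is re-estimated on the analysis set of every resample, the rebalancing is then done separately for each fold. There are multiple options for upsampling and downsampling.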

For example, the SMOTE algorithm:

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(themis)
#> Registered S3 methods overwritten by 'themis':
#>   method                  from   
#>   bake.step_downsample    recipes
#>   bake.step_upsample      recipes
#>   prep.step_downsample    recipes
#>   prep.step_upsample      recipes
#>   tidy.step_downsample    recipes
#>   tidy.step_upsample      recipes
#>   tunable.step_downsample recipes
#>   tunable.step_upsample   recipes
#> 
#> Attaching package: 'themis'
#> The following objects are masked from 'package:recipes':
#> 
#>     step_downsample, step_upsample
set.seed(111)

iris_cut <- iris[30:110,] # Dummy unbalanced dataset

iris_rec <- recipe(Species ~ ., data = iris_cut) %>%
  step_smote(Species)
  
iris_rec %>% prep() %>% bake(new_data = NULL)
#> # A tibble: 150 x 5
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#>           <dbl>       <dbl>        <dbl>       <dbl> <fct>  
#>  1          4.7         3.2          1.6         0.2 setosa 
#>  2          4.8         3.1          1.6         0.2 setosa 
#>  3          5.4         3.4          1.5         0.4 setosa 
#>  4          5.2         4.1          1.5         0.1 setosa 
#>  5          5.5         4.2          1.4         0.2 setosa 
#>  6          4.9         3.1          1.5         0.2 setosa 
#>  7          5           3.2          1.2         0.2 setosa 
#>  8          5.5         3.5          1.3         0.2 setosa 
#>  9          4.9         3.6          1.4         0.1 setosa 
#> 10          4.4         3            1.3         0.2 setosa 
#> # … with 140 more rows

Created on 2021-06-25 by the reprex package (v2.0.0)

You can read more about subsampling during modeling here and here.
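
To connect this back to the tuning code in the question, here is a minimal sketch of how the SMOTE recipe could replace the class.weights engine argument in the workflow. It is only one possible setup, not necessarily what the answer's author had in mind. Two settings are my own assumptions rather than part of the original post: strata = Species in vfold_cv() and neighbors = 3 in step_smote(). Both are there so that the small virginica class still has enough observations for SMOTE in every analysis set.

```r
library(tidymodels)
library(themis)
set.seed(111)

iris_cut <- iris[30:110, ] # Dummy unbalanced dataset

# Assumption: stratify the folds so each analysis set keeps some minority rows
folds <- vfold_cv(iris_cut, v = 3, strata = Species)

# The recipe is prepped again on each fold's analysis set, so SMOTE
# balances every fold on its own class counts.
# Assumption: neighbors = 3 (default is 5) because the minority class is tiny.
rec <- recipe(Species ~ ., data = iris_cut) %>%
  step_smote(Species, neighbors = 3)

# Same model specification as in the question, without class.weights
rf_mod <- rand_forest(mtry = tune(), trees = 1000, min_n = 1) %>%
  set_mode("classification") %>%
  set_engine("ranger")

tune_wf <- workflow() %>%
  add_recipe(rec) %>%
  add_model(rf_mod)

tune_res <- tune_grid(
  tune_wf,
  resamples = folds,
  grid = tibble(mtry = c(1, 2, 3))
)

collect_metrics(tune_res)
```

step_smote() has skip = TRUE by default, so the synthetic rows are only added when the model is fitted; the assessment set of each fold is left untouched when the metrics are computed.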

