I would like to use tidymodels to tune a ranger random forest in a cross-validated setup. My dataset is imbalanced, so I would like to use ranger's class.weights parameter.

However, the class weights differ from fold to fold. How can I pass fold-specific weights to the engine?



library(tidymodels)

iris_cut <- iris[30:110, ] # Dummy imbalanced dataset

# Create folds
set.seed(123)
folds <- vfold_cv(iris_cut, v = 3)

# Calculate class weights: one weight per class, in factor-level order
# (the order ranger expects for class.weights); the largest class gets 1
calc_weights <- function(df) {
  df %>%
    count(Species) %>%
    mutate(weight = max(n) / n) %>%
    pull(weight)
}

# Use during training of fold 1:
weights_fold1 <- analysis(folds$splits[[1]]) %>% calc_weights()
# Use during training of fold 2:
weights_fold2 <- analysis(folds$splits[[2]]) %>% calc_weights()
# Use during training of fold 3:
weights_fold3 <- analysis(folds$splits[[3]]) %>% calc_weights()

# Define a recipe
rec <- recipe(Species ~ ., data = iris_cut)

# Create model specification
rf_mod <- rand_forest(
  mtry = tune(),
  trees = 1000,
  min_n = 1
) %>%
  set_mode("classification") %>%
  set_engine("ranger",
             class.weights = !!weights_fold1) # Wrong! Each fold should get its own weight vector here

rf_grid <- crossing(
  mtry = c(1, 2, 3)
)

# Set up workflow
tune_wf <- workflow() %>%
  add_recipe(rec) %>%
  add_model(rf_mod)

# Start rf tuning
tune_res <- tune_grid(
  tune_wf,
  resamples = folds,
  grid = rf_grid
)
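The three copies of the weight calculation above can be collapsed into one pass over all splits. This is a minimal sketch, assuming that each fold's weights should come from its analysis (training) set and that ranger wants one weight per class in factor-level order; fold_weights is a hypothetical name. On its own it does not get the weights into tune_grid(), which is the open question.

```r
library(tidymodels)  # attaches rsample, dplyr and purrr

set.seed(123)
iris_cut <- iris[30:110, ]          # dummy imbalanced dataset
folds <- vfold_cv(iris_cut, v = 3)

# One weight per class in factor-level order; the largest class gets weight 1.
# .drop = FALSE keeps a row for every level even if a class is absent.
calc_weights <- function(df) {
  df %>%
    count(Species, .drop = FALSE) %>%
    mutate(weight = max(n) / n) %>%
    pull(weight)
}

# Hypothetical list: one weight vector per resample, computed on the
# analysis part of each split
fold_weights <- map(folds$splits, ~ calc_weights(analysis(.x)))

fold_weights
```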

The tidymodels ecosystem doesn't support passing fold-specific class weights like these to the engine. Instead, we encourage folks to use the themis package to handle class imbalance during feature engineering; it provides several upsampling and downsampling steps.

For example, the SMOTE algorithm:

library(tidymodels)
library(themis)

iris_cut <- iris[30:110, ] # Dummy imbalanced dataset

iris_rec <- recipe(Species ~ ., data = iris_cut) %>%
  step_smote(Species)

iris_rec %>% prep() %>% bake(new_data = NULL)
#> # A tibble: 150 x 5
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#>           <dbl>       <dbl>        <dbl>       <dbl> <fct>  
#>  1          4.7         3.2          1.6         0.2 setosa 
#>  2          4.8         3.1          1.6         0.2 setosa 
#>  3          5.4         3.4          1.5         0.4 setosa 
#>  4          5.2         4.1          1.5         0.1 setosa 
#>  5          5.5         4.2          1.4         0.2 setosa 
#>  6          4.9         3.1          1.5         0.2 setosa 
#>  7          5           3.2          1.2         0.2 setosa 
#>  8          5.5         3.5          1.3         0.2 setosa 
#>  9          4.9         3.6          1.4         0.1 setosa 
#> 10          4.4         3            1.3         0.2 setosa 
#> # … with 140 more rows
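Downsampling follows the same pattern with themis's step_downsample(). A minimal sketch on the same dummy data, assuming the default under_ratio = 1: every class is sampled down to the size of the smallest class, so with 10 virginica rows the baked training data should contain 10 rows per class. down_rec and down_data are illustrative names.

```r
library(tidymodels)
library(themis)

iris_cut <- iris[30:110, ]  # dummy imbalanced dataset: 21 setosa, 50 versicolor, 10 virginica

# Downsample every class to the size of the smallest one (under_ratio = 1 by default)
down_rec <- recipe(Species ~ ., data = iris_cut) %>%
  step_downsample(Species)

# bake(new_data = NULL) returns the processed training data, where the
# skip = TRUE subsampling step has been applied
down_data <- down_rec %>% prep() %>% bake(new_data = NULL)

count(down_data, Species)
```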

You can read more about subsampling during modeling here and here.
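To connect this back to the original tuning setup: because the subsampling step lives in the recipe, tune_grid() re-estimates it on the analysis set of every fold, which gives fold-specific balancing without computing any weights by hand. themis steps default to skip = TRUE, so the assessment sets keep their original class distribution. The sketch below is an adaptation of the question's code under a few assumptions, not a definitive recipe: it stratifies the folds by Species and lowers neighbors to 3 so that every fold keeps enough minority rows for SMOTE, and it uses 500 trees only to keep the run short.

```r
library(tidymodels)
library(themis)

set.seed(123)
iris_cut <- iris[30:110, ]                            # dummy imbalanced dataset
folds <- vfold_cv(iris_cut, v = 3, strata = Species)  # assumption: stratified folds

# SMOTE lives in the recipe, so it is re-estimated on each fold's analysis set
smote_rec <- recipe(Species ~ ., data = iris_cut) %>%
  step_smote(Species, neighbors = 3)                  # assumption: fewer neighbors for small folds

rf_mod <- rand_forest(mtry = tune(), trees = 500, min_n = 1) %>%
  set_mode("classification") %>%
  set_engine("ranger")                                # no class.weights needed

tune_wf <- workflow() %>%
  add_recipe(smote_rec) %>%
  add_model(rf_mod)

tune_res <- tune_grid(
  tune_wf,
  resamples = folds,
  grid = tibble(mtry = 1:3)
)

collect_metrics(tune_res)
```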
