Tuning XGBoost with tidymodels - not complete after 12 hours

Problem description

I have been running an XGBoost model in R on a high-performance machine (4 GHz, 16 cores, 32 GB RAM) for over 12 hours and it still has not finished. I am not sure what is going wrong. I followed Julia Silge's blog to a tee. Here is what my data looks like:

str(hts.facility.df)
tibble [24,422 x 47] (S3: tbl_df/tbl/data.frame)
$ patient_id                            : Factor w/ 24422 levels 
$ datim_code                            : chr [1:24422] 
$ sex                                   : Factor w/ 2 levels "F","M": 2 1 1 1 1 1 1 1 2 1 ...
$ age                                   : num [1:24422] 33 36 29 21 49 44 71 26 50 38 ...
$ age_group                             : Factor w/ 12 levels "< 1","1 - 4",..: 7 8 6 5 10 9 12 6 12 8 ...
$ referred_from                         : Factor w/ 3 levels "Self","TB","Other": 2 1 1 1 1 1 1 1 1 1 ...
$ marital_status                        : Factor w/ 4 levels "M","S","W","D": 1 1 2 1 1 1 3 2 1 2 ...
$ no_of_own_children_lessthan_5         : Factor w/ 2 levels "0","more_than_2_children": 2 1 1 1 1 1 1 1 1 1 ...
$ no_of_wives                           : Factor w/ 2 levels "0","more_than_2_wives": 2 1 1 1 1 1 1 1 1 1 ...
$ session_type                          : Factor w/ 2 levels "Couple","Individual": 2 2 2 2 2 2 2 2 2 2 ...
$ previously_tested_hiv_negative        : Factor w/ 2 levels "0","1": 2 2 2 2 2 2 2 2 2 2 ...
$ client_pregnant                       : Factor w/ 2 levels "1","0": 2 1 1 1 1 1 1 1 2 1 ...
$ hts_test_result                       : Factor w/ 2 levels "Neg","Pos": 1 1 1 1 1 1 1 1 1 1 ...
$ hts_setting                           : Factor w/ 4 levels "CT","TB","Ward",..: 3 3 3 3 3 3 3 3 3 3 ...
$ tested_for_hiv_before_within_this_year: Factor w/ 2 levels "PreviouslyTestedNegative",..: 2 1 2 2 2 2 2 2 2 2 ...
$ is_surge_site                         : Factor w/ 2 levels "0","1": 1 1 1 1 1 1 1 1 1 1 ...
$ nga_agesex_f_15_2019                  : num [1:24422] 0.0627 0.0627 0.0627 0.0627 0.0627 ...
$ nga_agesex_f_20_2019                  : num [1:24422] 0.0581 0.0581 0.0581 0.0581 0.0581 ...
$ nga_agesex_f_25_2019                  : num [1:24422] 0.0411 0.0411 0.0411 0.0411 0.0411 ...
$ nga_agesex_f_30_2019                  : num [1:24422] 0.0314 0.0314 0.0314 0.0314 0.0314 ...
$ nga_agesex_f_35_2019                  : num [1:24422] 0.0275 0.0275 0.0275 0.0275 0.0275 ...
$ nga_agesex_f_40_2019                  : num [1:24422] 0.021 0.021 0.021 0.021 0.021 ...
$ nga_agesex_f_45_2019                  : num [1:24422] 0.0166 0.0166 0.0166 0.0166 0.0166 ...
$ nga_agesex_m_15_2019                  : num [1:24422] 0.0536 0.0536 0.0536 0.0536 0.0536 ...
$ nga_agesex_m_20_2019                  : num [1:24422] 0.0632 0.0632 0.0632 0.0632 0.0632 ...
$ nga_agesex_m_25_2019                  : num [1:24422] 0.0534 0.0534 0.0534 0.0534 0.0534 ...
$ nga_agesex_m_30_2019                  : num [1:24422] 0.036 0.036 0.036 0.036 0.036 ...
$ nga_agesex_m_35_2019                  : num [1:24422] 0.0325 0.0325 0.0325 0.0325 0.0325 ...
$ nga_agesex_m_40_2019                  : num [1:24422] 0.0263 0.0263 0.0263 0.0263 0.0263 ...
$ nga_agesex_m_45_2019                  : num [1:24422] 0.0236 0.0236 0.0236 0.0236 0.0236 ...
$ IHME_CONDOM_LAST_TIME_PREV_MEAN_2017  : num [1:24422] 14.1 14.1 14.1 14.1 14.1 ...
$ IHME_HAD_INTERCOURSE_PREV_MEAN_2017   : num [1:24422] 63.1 63.1 63.1 63.1 63.1 ...
$ IHME_HIV_COUNT_MEAN_2017              : num [1:24422] 0.0126 0.0126 0.0126 0.0126 0.0126 ...
$ IHME_IN_UNION_PREV_MEAN_2017          : num [1:24422] 56.9 56.9 56.9 56.9 56.9 ...
$ IHME_MALE_CIRCUMCISION_PREV_MEAN_2017 : num [1:24422] 98.7 98.7 98.7 98.7 98.7 ...
$ IHME_PARTNER_AWAY_PREV_MEAN_2017      : num [1:24422] 13.5 13.5 13.5 13.5 13.5 ...
$ IHME_PARTNERS_YEAR_MN_PREV_MEAN_2017  : num [1:24422] 13.5 13.5 13.5 13.5 13.5 ...
$ IHME_PARTNERS_YEAR_WN_PREV_MEAN_2017  : num [1:24422] 3.07 3.07 3.07 3.07 3.07 ...
$ IHME_STI_SYMPTOMS_PREV_MEAN_2017      : num [1:24422] 4.15 4.15 4.15 4.15 4.15 ...
$ wp_contraceptive                      : num [1:24422] 0.282 0.282 0.282 0.282 0.282 ...
$ wp_liveBirths                         : num [1:24422] 124 124 124 124 124 ...
$ wp_poverty                            : num [1:24422] 0.555 0.555 0.555 0.555 0.555 ...
$ wp_lit_men                            : num [1:24422] 0.967 0.967 0.967 0.967 0.967 ...
$ wp_lit_women                          : num [1:24422] 0.874 0.874 0.874 0.874 0.874 ...
$ wp_stunting_men                       : num [1:24422] 0.178 0.178 0.178 0.178 0.178 ...
$ wp_stunting_women                     : num [1:24422] 0.215 0.215 0.215 0.215 0.215 ...
$ road_density_km                       : num [1:24422] 82.3 82.3 82.3 82.3 82.3 ... 

Here is the code I am running:

set.seed(4488)

hts.facility.df2 = hts.facility.df %>% 
  mutate(hts_test_result = as.factor(case_when(
    hts_test_result == 'Pos' ~ 1,
    hts_test_result == 'Neg' ~ 0
  )))

# split data into training and test using hts test result column ----------------------
df.split = initial_split(hts.facility.df2, strata = hts_test_result) # default split is 0.75/0.25

train.df = training(df.split)
test.df = testing(df.split)


# recipe for XGBoost model ------------------------------------------------
# use themis package for oversampling: https://github.com/tidymodels/themis 
#  for more info on SMOTE method for unbalanced data refer: https://jair.org/index.php/jair/article/view/10302/24590

hts_recipe = recipe(hts_test_result ~ ., data = train.df) %>%  
  # remove individual data - patient id and facility id, and age since age_group is already in the dataset
  step_rm(patient_id, datim_code, age) %>% 
  update_role(patient_id, new_role = "patient_ID") %>% 
  update_role(datim_code, new_role = "facility_id") %>% 
  step_dummy(all_nominal(), -all_outcomes()) %>% 
  # normalize numeric variables
  step_normalize(all_predictors()) %>% 
  # oversample the minority class (positive tests) with SMOTE, since negatives dominate the results
  themis::step_smote(hts_test_result, over_ratio = 1)

hts_tree_prep <- prep(hts_recipe)

# extract the preprocessed (baked) training data
hts_juiced <- juice(hts_tree_prep)
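# (suggested check, not in the original code) after SMOTE the two outcome
# classes should be roughly balanced:
hts_juiced %>% count(hts_test_result)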

xgb_spec <- boost_tree(
  trees = 500,
  tree_depth = tune(), min_n = tune(),
  loss_reduction = tune(),                     ## first three: model complexity
  sample_size = tune(), mtry = tune(),         ## randomness
  learn_rate = tune(),                         ## step size
) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

#  set up grid for tuning values -------------------
xgb_grid <- grid_latin_hypercube(
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), train.df),
  learn_rate(),
  size = 20
)
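# 20 candidate hyperparameter combinations, one per row; the Latin hypercube
# keeps them spread across the search space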

xgb_grid


# set up workflow ---------------------------------------------------------


xgb_wf <- workflow() %>%
  add_formula(hts_test_result ~ .) %>%
  add_model(xgb_spec)

set.seed(123)
vb_folds <- vfold_cv(train.df, strata = hts_test_result)
vb_folds
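# note: vfold_cv() defaults to v = 10, so tuning will fit
# 10 folds x 20 grid points = 200 xgboost models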


# tune the model ----------------------------------------------------------

doParallel::registerDoParallel()
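# (suggested) with no arguments registerDoParallel() picks a default worker
# count for you; you can make it explicit, e.g.:
# doParallel::registerDoParallel(cores = parallel::detectCores() - 1)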


set.seed(234)
xgb_res <- tune::tune_grid(
  xgb_wf,
  resamples = vb_folds,
  grid = xgb_grid,
  control = control_grid(save_pred = TRUE)
)

This is where it has been stuck for the past 12 hours. My dataset is fairly small, so why is it taking this long?

Tags: r, machine-learning, xgboost, tidymodels

Solution
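The recipe work above is almost certainly being wasted, and that is also the likely cause of the runtime. The workflow is built with add_formula(hts_test_result ~ .) instead of add_recipe(hts_recipe), so none of the preprocessing (dropping the IDs, dummy coding, normalizing, SMOTE) is ever applied. Worse, the formula interface dummy-codes every factor in the raw training data, including patient_id, a factor with 24,422 levels, so each resample fits 500 boosted trees to a matrix with more than 24,000 columns, 200 times over (10 folds x 20 grid points). A sketch of the fix, swapping the formula for the recipe and leaving everything else as is:

xgb_wf <- workflow() %>%
  add_recipe(hts_recipe) %>%   # use the recipe so patient_id, datim_code and age are dropped
  add_model(xgb_spec)

With the recipe attached, note that the two update_role() calls are redundant, since step_rm() already removes patient_id and datim_code; keep one mechanism or the other. There is also no need to prep() and juice() the recipe yourself before tuning (it is useful for inspection only): tune_grid() preps it inside each resample, which is what you want so that SMOTE only sees the analysis portion of each fold.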

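If you want to see the blow-up directly, build the design matrix that the formula interface implies and check its dimensions. A diagnostic sketch on a 500-row sample (the sample size is arbitrary; the full matrix may not fit comfortably in memory):

# model.matrix() creates a column per factor level, whether or not the level
# appears in the sampled rows, so even 500 rows show the ~24,000-column width
small <- dplyr::slice_sample(train.df, n = 500)
dim(model.matrix(hts_test_result ~ ., data = small))

dim(hts_juiced)  # the recipe's baked data, by contrast, is only a few dozen columns wide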

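If tuning is still slower than you would like after the fix, racing can help: tune_race_anova() from the finetune package drops clearly losing parameter combinations after the first few resamples instead of evaluating the full grid on every fold. A sketch, assuming finetune is installed:

library(finetune)

set.seed(234)
xgb_res <- tune_race_anova(
  xgb_wf,
  resamples = vb_folds,
  grid = xgb_grid,
  control = control_race(save_pred = TRUE)
)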