r - 使用 TIdymodels 调优 XGBoost - 12 小时后仍未完成
问题描述
我已经在高性能机器(4Ghz,16 核,32gb RAM)上运行 R 中的 XGBoost 模型超过 12 小时,但仍未完成。我不确定出了什么问题。我跟着Julia Silge 的博客来到了发球台。这是我的数据的样子:
str(hts.facility.df)
tibble [24,422 x 47] (S3: tbl_df/tbl/data.frame)
$ patient_id : Factor w/ 24422 levels
$ datim_code : chr [1:24422]
$ sex : Factor w/ 2 levels "F","M": 2 1 1 1 1 1 1 1 2 1 ...
$ age : num [1:24422] 33 36 29 21 49 44 71 26 50 38 ...
$ age_group : Factor w/ 12 levels "< 1","1 - 4",..: 7 8 6 5 10 9 12 6 12 8 ...
$ referred_from : Factor w/ 3 levels "Self","TB","Other": 2 1 1 1 1 1 1 1 1 1 ...
$ marital_status : Factor w/ 4 levels "M","S","W","D": 1 1 2 1 1 1 3 2 1 2 ...
$ no_of_own_children_lessthan_5 : Factor w/ 2 levels "0","more_than_2_children": 2 1 1 1 1 1 1 1 1 1 ...
$ no_of_wives : Factor w/ 2 levels "0","more_than_2_wives": 2 1 1 1 1 1 1 1 1 1 ...
$ session_type : Factor w/ 2 levels "Couple","Individual": 2 2 2 2 2 2 2 2 2 2 ...
$ previously_tested_hiv_negative : Factor w/ 2 levels "0","1": 2 2 2 2 2 2 2 2 2 2 ...
$ client_pregnant : Factor w/ 2 levels "1","0": 2 1 1 1 1 1 1 1 2 1 ...
$ hts_test_result : Factor w/ 2 levels "Neg","Pos": 1 1 1 1 1 1 1 1 1 1 ...
$ hts_setting : Factor w/ 4 levels "CT","TB","Ward",..: 3 3 3 3 3 3 3 3 3 3 ...
$ tested_for_hiv_before_within_this_year: Factor w/ 2 levels "PreviouslyTestedNegative",..: 2 1 2 2 2 2 2 2 2 2 ...
$ is_surge_site : Factor w/ 2 levels "0","1": 1 1 1 1 1 1 1 1 1 1 ...
$ nga_agesex_f_15_2019 : num [1:24422] 0.0627 0.0627 0.0627 0.0627 0.0627 ...
$ nga_agesex_f_20_2019 : num [1:24422] 0.0581 0.0581 0.0581 0.0581 0.0581 ...
$ nga_agesex_f_25_2019 : num [1:24422] 0.0411 0.0411 0.0411 0.0411 0.0411 ...
$ nga_agesex_f_30_2019 : num [1:24422] 0.0314 0.0314 0.0314 0.0314 0.0314 ...
$ nga_agesex_f_35_2019 : num [1:24422] 0.0275 0.0275 0.0275 0.0275 0.0275 ...
$ nga_agesex_f_40_2019 : num [1:24422] 0.021 0.021 0.021 0.021 0.021 ...
$ nga_agesex_f_45_2019 : num [1:24422] 0.0166 0.0166 0.0166 0.0166 0.0166 ...
$ nga_agesex_m_15_2019 : num [1:24422] 0.0536 0.0536 0.0536 0.0536 0.0536 ...
$ nga_agesex_m_20_2019 : num [1:24422] 0.0632 0.0632 0.0632 0.0632 0.0632 ...
$ nga_agesex_m_25_2019 : num [1:24422] 0.0534 0.0534 0.0534 0.0534 0.0534 ...
$ nga_agesex_m_30_2019 : num [1:24422] 0.036 0.036 0.036 0.036 0.036 ...
$ nga_agesex_m_35_2019 : num [1:24422] 0.0325 0.0325 0.0325 0.0325 0.0325 ...
$ nga_agesex_m_40_2019 : num [1:24422] 0.0263 0.0263 0.0263 0.0263 0.0263 ...
$ nga_agesex_m_45_2019 : num [1:24422] 0.0236 0.0236 0.0236 0.0236 0.0236 ...
$ IHME_CONDOM_LAST_TIME_PREV_MEAN_2017 : num [1:24422] 14.1 14.1 14.1 14.1 14.1 ...
$ IHME_HAD_INTERCOURSE_PREV_MEAN_2017 : num [1:24422] 63.1 63.1 63.1 63.1 63.1 ...
$ IHME_HIV_COUNT_MEAN_2017 : num [1:24422] 0.0126 0.0126 0.0126 0.0126 0.0126 ...
$ IHME_IN_UNION_PREV_MEAN_2017 : num [1:24422] 56.9 56.9 56.9 56.9 56.9 ...
$ IHME_MALE_CIRCUMCISION_PREV_MEAN_2017 : num [1:24422] 98.7 98.7 98.7 98.7 98.7 ...
$ IHME_PARTNER_AWAY_PREV_MEAN_2017 : num [1:24422] 13.5 13.5 13.5 13.5 13.5 ...
$ IHME_PARTNERS_YEAR_MN_PREV_MEAN_2017 : num [1:24422] 13.5 13.5 13.5 13.5 13.5 ...
$ IHME_PARTNERS_YEAR_WN_PREV_MEAN_2017 : num [1:24422] 3.07 3.07 3.07 3.07 3.07 ...
$ IHME_STI_SYMPTOMS_PREV_MEAN_2017 : num [1:24422] 4.15 4.15 4.15 4.15 4.15 ...
$ wp_contraceptive : num [1:24422] 0.282 0.282 0.282 0.282 0.282 ...
$ wp_liveBirths : num [1:24422] 124 124 124 124 124 ...
$ wp_poverty : num [1:24422] 0.555 0.555 0.555 0.555 0.555 ...
$ wp_lit_men : num [1:24422] 0.967 0.967 0.967 0.967 0.967 ...
$ wp_lit_women : num [1:24422] 0.874 0.874 0.874 0.874 0.874 ...
$ wp_stunting_men : num [1:24422] 0.178 0.178 0.178 0.178 0.178 ...
$ wp_stunting_women : num [1:24422] 0.215 0.215 0.215 0.215 0.215 ...
$ road_density_km : num [1:24422] 82.3 82.3 82.3 82.3 82.3 ...
这是我正在运行的代码:
set.seed(4488)
hts.facility.df2 = hts.facility.df %>%
mutate(hts_test_result = as.factor(case_when(
hts_test_result == 'Pos' ~ 1,
hts_test_result == 'Neg' ~ 0
)))
# split data into training and test using hts test result column ----------------------
df.split = initial_split(hts.facility.df2, strata = hts_test_result) # default split if .75/.25
train.df = training(df.split)
test.df = testing(df.split)
# recipe for Random Forest model ------------------------------------------------
# use themis package for oversampling: https://github.com/tidymodels/themis
# for more info on SMOTE method for unbalanced data refer: https://jair.org/index.php/jair/article/view/10302/24590
hts_recipe = recipe(hts_test_result ~ ., data = train.df) %>%
# remove individual data - patient id and facility id and age since age-grouo is already in the dataset
step_rm(patient_id, datim_code, age) %>%
update_role(patient_id, new_role = "patient_ID") %>%
update_role(datim_code, new_role = "facility_id") %>%
step_dummy(all_nominal(), -all_outcomes()) %>%
# # normalize numeric variables
step_normalize(all_predictors()) %>%
# downsample positive tests as they are 90% of the results -
themis::step_smote(hts_test_result, over_ratio = 1)
hts_tree_prep <- prep(hts_recipe)
# create the data frame
hts_juiced <- juice(hts_tree_prep)
xgb_spec <- boost_tree(
trees = 500,
tree_depth = tune(), min_n = tune(),
loss_reduction = tune(), ## first three: model complexity
sample_size = tune(), mtry = tune(), ## randomness
learn_rate = tune(), ## step size
) %>%
set_engine("xgboost") %>%
set_mode("classification")
# set up grid for tuning values -------------------
xgb_grid <- grid_latin_hypercube(
tree_depth(),
min_n(),
loss_reduction(),
sample_size = sample_prop(),
finalize(mtry(), train.df),
learn_rate(),
size = 20
)
xgb_grid
# set up workflow ---------------------------------------------------------
xgb_wf <- workflow() %>%
add_formula(hts_test_result ~ .) %>%
add_model(xgb_spec)
set.seed(123)
vb_folds <- vfold_cv(train.df, strata = hts_test_result)
vb_folds
# tune the model ----------------------------------------------------------
doParallel::registerDoParallel()
set.seed(234)
xgb_res <- tune::tune_grid(
xgb_wf,
resamples = vb_folds,
grid = xgb_grid,
control = control_grid(save_pred = TRUE)
)
这是过去 12 小时卡住的地方。我的数据集这么小,为什么要花这么长时间?
解决方案
推荐阅读
- java - TCP 多线程
- bootstrap-4 - 添加引导输入组类时缺少输入框边框
- r - gg在R中绘制时间序列,x轴上有间隔
- mysql - 生成 SQL 查询
- javascript - GET 方法中日期参数的“无效日期”错误
- python - scipy.spatial 的简单 2D Convex Hull 错误
- visual-studio-2017 - Azure 开发操作——将 Visual Studio 2017 (VS2017) 设置为克隆项目的默认值
- git - 在 SmartGit 日志视图中过滤掉合并的分支历史
- jenkins - 使用管道脚本与 Jenkins 一起部署到 CentOS 7.5
- python - 如何处理激活环境的错误?