r - 保存 XGBoost 模型的工作流程以备后用
问题描述
我对 R 比较陌生,因此如果这个问题措辞奇怪或不清楚,请提前道歉。
我基于 tidymodels 框架构建了一个 xgboost 模型,正如 Julia 在她关于该主题的 youtube 视频中所展示的那样(使用 tidymodels 调整 XGBoost - YouTube)。
我现在遇到的问题是我想保存模型以备后用(这样我就可以将其加载到其中而无需再次构建模型),尽管我在网上阅读到使用该saveRDS
函数保存它可能会在打包的情况下导致兼容性问题未来的版本更新。
一个更强大的选项似乎是"xgb.save"
or "xgb.save.raw"
,尽管要利用此功能,模型应该是 class xgb.booster
,而我的模型是 class 工作流。
这导致我想到几个问题:
- 如果对象属于工作流类型,还是仅针对 xgb.booster 类型的对象,是否也会出现兼容性问题?
- 有没有办法在我正在使用的 tidymodels 框架内将工作流转换为 xgb.booster 模型,或者我应该从另一个角度解决这个问题并尝试找到一种更好的方法来保存工作流?
我将不胜感激任何形式的回应。
下面是我要澄清的一些代码(over_under 是我的 DV,代表超过 500 封电子邮件/不到 500 封电子邮件):
xgb_bundes <- boost_tree(
trees = tune(),
tree_depth = tune(),
sample_size = tune(),
mtry = tune(),
learn_rate = 0.06
) %>%
set_engine("xgboost") %>%
set_mode("classification")
xgb_grid <- grid_latin_hypercube(
trees(),
tree_depth(),
sample_size = sample_prop(),
finalize(mtry(), data_train),
size = 30
)
xgb_wf_bundes <- workflow() %>%
add_formula(over_under ~ .) %>%
add_model(xgb_bundes)
set.seed(210)
data_fold_cv<- vfold_cv(data_train, v = 10, strata = over_under)
library(doParallel)
cores<-detectCores()
cl <- makeCluster(cores[1]-1)
#Register cluster
registerDoParallel(cl)
set.seed(285)
xgb_bundes_result <- tune_grid(
xgb_wf_bundes,
resamples = data_fold_cv,
grid = xgb_grid,
control = control_grid(save_pred = T)
)
metrics_results <- xgb_bundes_result %>% collect_metrics()
best <- metrics_results %>% filter(.metric == "accuracy")
final_xgb_bundes <- finalize_workflow(xgb_wf_bundes, best)
model_bundes <- final_xgb_bundes %>% fit(data_train)
class(model_bundes)
[1] "workflow"
xgb.save(model_bundes, "model bundes")
---Which gives the following error:
Error in xgb.save(model_bundes, "model bundes") : model must be xgb.Booster.
解决方案
推荐阅读
- javascript - 如何将本地js文件实现到ejs中
- django - 依赖项引用了一个不存在的父节点('auth', '0012_alter_user_first_name_max_length')
- javascript - 获取表中的所有行值
- javascript - 如何在没有按钮的下拉菜单上使用onclick
- dolphindb - 在 dolphindb 的列中动态创建和追加数据
- android - Jetpack Compose 中未显示 Snackbar
- spring - Feign 的问题 - 向包含“/api”的 url 发出请求返回 403
- python-3.x - Discord.py 如果消息是 gif
- python - Jinja IF 语句不适用于脚本 html 标签
- flutter - 颤振容器底部中心对齐在堆栈中未在底部完全对齐?