首页 > 解决方案 > 如何将经过训练的 xgboost 基础模型参数加载到 xgboost sklearn API 中?

问题描述

当使用基本 API(即 xgboost.train(args))在 Python 中训练和保存 xgboost 模型时,我们可以使用 .save_model() 保存参数:

import xgboost

model = xgboost.train(args)  # Learning API
model.save_model(args)

loaded_model = xgboost.XGBRegressor()  # Scikit-Learn API
loaded_model.load_model(args)

我们如何将这个经过训练的模型加载到 xgboost sklearn API 中?我的目标是将经过训练的 xgboost 模型(使用 Learning API 训练)作为拟合模型加载到 xgboost Scikit-Learn API 中,这样我就可以利用其他 sklearn 函数。

我在上面代码中包含的方法不允许加载的模型与其他 sklearn 函数一起使用,当我尝试在模型上使用其他 sklearn 函数时,我得到了 NotFittedError。

这是我正在使用的模型的 Python API 的链接:https ://xgboost.readthedocs.io/en/latest/python/python_api.html

我正在使用“Learning API”训练模型并尝试将模型加载到“Scikit-Learn API”中。

标签: pythonscikit-learnxgboost

解决方案


假设您使用了标准分类器或scikit-learn包模型之一,您可以使用以下方法保存和加载模型pickle

import pickle
 
model.train(X)
saved_model = pickle.dumps(model)
 
# Load the pickled model
loaded_model = pickle.loads(saved_model)
 
# Using the loaded model to predict new data
loaded_model.predict(X_test)

您还可以将其保存saved_model到任意文件中,然后再加载。

import pickle
 
model.train(X)

file_pi = open('model.obj', 'w') 
pickle.dump(model, file_pi)

 
# Load the pickled model
filehandler = open(filename, 'r') 
loaded_model = pickle.load(filehandler)
 
# Using the loaded model to predict new data
loaded_model.predict(X_test)

推荐阅读