首页 > 解决方案 > Python - 通过传递带有模型名称的字符串作为输入来获取 Scikit-Learn 分类器

问题描述

我想通过将带有模型名称的字符串作为输入传递来检索特定的 SKLearn 模型对象。例如,目前我有这个来加载 MultinomialNB 模型

from sklearn.naive_bayes import MultinomialNB

nb = MultinomialNB(alpha=1.0,
                   class_prior=None,
                   fit_prior=True)

我想要一种方法:

def get_model(model_name):
    (...)
    return model

这样当我执行 get_model("MultinomialNB") 时,我会得到与nb上面代码相​​同的对象。为此在 Scikit-Learn 中实现了什么?

标签: pythonscikit-learn

解决方案


One option is to use importlib. It will force you to also pass the module from which to import, though. With this approach, model hyperparameters should also be parametrized in the function call.

Example:

import importlib
import sklearn

def get_model(
    model_name: str, import_module: str, model_params: dict
) -> sklearn.base.BaseEstimator:
    """Returns a scikit-learn model."""
    model_class = getattr(importlib.import_module(import_module), model_name)
    model = model_class(**model_params)  # Instantiates the model
    return model

which then you could call by doing e.g.

model_params = {"alpha": 1.0, "class_prior": None, "fit_prior": True}
nb = get_model("MultinomialNB", "sklearn.naive_bayes", model_params)

推荐阅读