python - 带有 fit_transfrom 或预测对象而不是拟合对象的 sklearn 管道
问题描述
sklearn 网站上的这个例子和这个对 sklearn pipelines on SO 的回答使用和讨论了在 Pipleines 中的使用.fit()
或.fit_transform()
方法。
但是,如何在 Pipelines 中使用 .predict 或 .transfrom 方法。假设我已经预处理了我的训练数据,搜索了最佳超参数并训练了一个 LightGBM 模型。我现在想预测新数据,而不是手动做所有上述事情,我想根据定义一个接一个地做它们:
依次应用变换列表和最终估计器。管道的中间步骤必须是“变换”,即它们必须实现拟合和变换方法。最终估计器只需要实现拟合。
但是,我只想.transform
对我的验证(或测试)数据和更多的函数(或类)实现方法,这些函数(或类)采用 pandas 系列(或 DataFrame 或 numpy 数组)并返回处理过的一个,然后最终实现.predict
我的 LightGBM 的方法,这将使用我已经拥有的超参数。
我目前什么都没有,因为我不知道如何正确地包含类的方法(比如
StandardScaler_instance.transform()
)和更多这样的方法。!
我该怎么做或我错过了什么?
解决方案
您必须构建您的管道,其中将包括 LightGBM 模型并在您的(预处理的)训练数据上训练管道。
使用代码,它可能如下所示:
import lightgbm
from sklearn.pipeline import Pipeline
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# Create some train and test data
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# Define pipeline with scaler and lightgbm model
pipe = Pipeline([('scaler', StandardScaler()), ('lightgbm', lightgbm.LGBMClassifier())])
# Train pipeline
pipe.fit(X_train, y_train)
# Make predictions with pipeline (with lightgbm)
print("Predictions:", pipe.predict(X_test))
# Evaluate pipeline performance
print("Performance score:", pipe.score(X_test, y_test))
输出:
Predictions: [1 0 1 0 0 0 1 0 1 1 1 0 0 1 0 1 0 0 1 1 1 0 1 0 0]
Performance score: 0.84
所以回答你的问题:
但是,如何在 Pipelines 中使用 .predict 或 .transfrom 方法。
- 您不必使用 .transform,因为管道会使用提供的转换器自动处理输入数据的转换。这就是为什么它在文档中提到:
管道的中间步骤必须是“变换”,即它们必须实现拟合和变换方法。
- 您可以将代码示例中所示的 .predict 用于您的测试数据。
除了我在本示例中使用的 StandardScaler,您可以为管道提供自定义转换器,但它必须实现管道可以调用的 .transform() 和 .fit() 方法,并且转换器的输出需要匹配lightgbm 模型所需的输入。
更新
然后,您可以为管道的不同步骤提供参数,如此处文档中所述:
** fit_paramsdict of string -> object传递给
fit
每个 step 的方法的参数,其中每个参数名称都带有前缀,p
以便 step的参数s
具有 keys__p
。
推荐阅读
- c# - 以编程方式将父 ViewModel 的命令添加到按钮
- python - 如何使用 tf.data.Dataset.from_tensor_slices 和 map 加载 np.array
- django - 将 websockets 集成到 Django Rest Framework 应用程序的简单方法?
- akka - 测试 AKKA 2.6 持久性演员类型(Kill 和 PoisonPill 的替代品)
- sql - 如何解析以下 JSON 字段?(PostgreSQL)
- c# - Google drive api v3 在同一请求中创建多个副本以在 C# 中驱动服务
- python-3.x - 循环内的 Python PyQt5 组合框连接
- python-c-api - 如何正确地将 CuPy 数组发送到 Python C 扩展模块
- stripe-payments - 具有不同持续时间的产品的条带订阅
- npm - 如何从相对版本中获取确切的最新 npm 版本?