scikit-learn - 使用 RFECV 和 GridSearchCV 堆叠 StandardScaler()
问题描述
所以我发现 StandardScaler() 可以让我的 RFECV 在我的 GridSearchCV 中,每个嵌套的 3 折交叉验证运行得更快。如果没有 StandardScaler(),我的代码运行了超过 2 天,所以我取消并决定将 StandardScaler 注入到进程中。但现在它已经运行了 4 个多小时,我不确定我是否做得对。这是我的代码:
# Choose Linear SVM as classifier
LSVM = SVC(kernel='linear')
selector = RFECV(LSVM, step=1, cv=3, scoring='f1')
param_grid = [{'estimator__C': [0.001, 0.01, 0.1, 1, 10, 100]}]
clf = make_pipeline(StandardScaler(),
GridSearchCV(selector,
param_grid,
cv=3,
refit=True,
scoring='f1'))
clf.fit(X, Y)
老实说,我认为我的做法并不正确,因为我认为 StandardScaler() 应该放在 GridSearchCV() 函数中,以便每次折叠都标准化数据,而不仅仅是一次(?)。如果我错了或者我的管道不正确以及为什么它仍然运行了很长时间,请纠正我。
我有 8,000 行的 145 个特征要被 RFECV 修剪,以及 6 个 C 值要被 GridSearchCV 修剪。因此对于每个 C-Value,最佳特征集由 RFECV 确定。
谢谢!
更新:
所以我将 StandardScaler 放在 RFECV 中,如下所示:
clf = SVC(kernel='linear')
kf = KFold(n_splits=3, shuffle=True, random_state=0)
estimators = [('standardize' , StandardScaler()),
('clf', clf)]
class Mypipeline(Pipeline):
@property
def coef_(self):
return self._final_estimator.coef_
@property
def feature_importances_(self):
return self._final_estimator.feature_importances_
pipeline = Mypipeline(estimators)
rfecv = RFECV(estimator=pipeline, cv=kf, scoring='f1', verbose=10)
param_grid = [{'estimator__svc__C': [0.001, 0.01, 0.1, 1, 10, 100]}]
clf = GridSearchCV(rfecv, param_grid, cv=3, scoring='f1', verbose=10)
但它仍然抛出以下错误:
ValueError: 估计器管道的参数 C 无效(memory=None, steps=[('standardscaler', StandardScaler(copy=True, with_mean=True, >with_std=True)), ('svc', SVC(C=1.0, cache_size =200,class_weight=None,>coef0=0.0,decision_function_shape='ovr',degree=3,gamma='auto',kernel='linear',max_iter=-1,probability=False,random_state=None,shrinking=True , tol=0.001, 详细=False))])。使用 > 检查可用参数列表
estimator.get_params().keys()
。
解决方案
库马尔是对的。此外,您可能想要做的是在 GridSearchCV 中打开详细信息。此外,您可以对 SVC 的迭代次数添加一个限制,从一个非常小的数字(例如 5)开始,以确保问题不在于收敛。
推荐阅读
- javascript - 无法从“App.js”(Expo CLI)解析“react”
- mysql - 如何修复 root 帐户获取“错误 1130 (HY000): Host ... is not allowed to connect to this MariaDB server”
- windows - 无法在 IntelliJ Windows 上安装 Git(退出代码 128)编辑:已解决
- c# - Octokit.net 非常慢
- r - 如何在r中预测未来日期的值
- javascript - 尝试单击 openlayers 地图功能以执行单击锚标记以滚动到书签
- python - 将 pandas 中的 CSV 文件导入到带有解析错误的 pandas 数据框中
- python - 我有什么机会从 Python 调用 CICS 事务或 COBOL 程序
- database - 使用网站数据和功能的颤振应用程序的技巧
- android - 从 Fabric 切换到 Firebase Crashlytics 问题:缺少 Crashlytics 构建 ID