首页 > 解决方案 > 调整超参数时出现 GridSearchCV 错误

问题描述

在为 GridSearch 进行超参数调整时出现奇怪的错误。我从 randomsearchcv 中获得了一些最佳参数,我正在尝试将这些参数拟合到网格搜索 cv 中。我收到错误

参数(标准)的参数网格需要是列表或numpy数组,但得到(<class'str'>)。单个值需要包含在一个包含一个元素的列表中。

下面是代码

from sklearn.model_selection import GridSearchCV
clf=RandomForestClassifier()
n_estimators=rf_random_tuned.best_params_['n_estimators']
criterion=rf_random_tuned.best_params_['criterion']
max_depth=rf_random_tuned.best_params_['max_depth']
min_samples_split=rf_random_tuned.best_params_['min_samples_split']
min_samples_leaf=rf_random_tuned.best_params_['min_samples_leaf']
max_features=rf_random_tuned.best_params_['max_features']
param_grid_1={'n_estimators':[n_estimators-100,n_estimators,n_estimators+100],
           'criterion':criterion,
           'max_depth':[max_depth-1,max_depth-0.5,max_depth,max_depth+0.5,max_depth+1],
           'min_samples_split':[min_samples_split-14,min_samples_split,min_samples_split+14],
           'min_samples_leaf':[min_samples_leaf-0.16,min_samples_leaf,min_samples_leaf+0.16],
           'max_features':max_features
          }
rf_grid=GridSearchCV(estimator=clf,param_grid=param_grid_1,cv=5)
rf_grid.fit(X_train,y_train)

标签: scikit-learnhyperparameters

解决方案


文档

param_grid:字典或字典列表

以参数名称 (str) 作为键的字典和要尝试作为值的参数设置列表,或此类字典的列表,在这种情况下,将探索列表中每个字典跨越的网格。这可以搜索任何参数设置序列。

基本上,它抱怨 wrt 键'criterion''max_features',其值应作为列表传递。


推荐阅读