Grid search over a VotingClassifier that contains a Keras model

Problem description

I am trying to run GridSearchCV over a VotingClassifier that includes a Keras model.

Here is the code:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.ensemble import VotingClassifier
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

# pretend data
X, y = datasets.make_classification(n_samples=100, n_features=20)

scaler = StandardScaler()

# create model
def create_model():
    model = Sequential()

    model.add(Dense(20, kernel_initializer="uniform", activation='relu', input_shape=(20,)))
    model.add(Dense(30, kernel_initializer="uniform", activation='relu'))
    model.add(Dense(10, kernel_initializer="uniform", activation='relu'))
    model.add(Dense(1,  activation='sigmoid'))

    # Compile model
    optimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    return model

keras_model = KerasClassifier(build_fn=create_model)
keras_model._estimator_type = "classifier"   

eclf = VotingClassifier(
    estimators=[('svc', SVC(probability=True)), ('keras_model', keras_model)],
    voting='soft')

# Test - fit the voting classifier without grid search
eclf.fit(X, y)

print('The VotingClassifier can be fit outside of gridsearch\n')

# parameters to grid search
params = [{'svc__C':[0.01,0.1]}, ]

grid = GridSearchCV(eclf,params,cv=2,scoring='accuracy', verbose=1)
grid.fit(X,y)

I get the following error:

ValueError: The estimator KerasClassifier should be a classifier.

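When I fit the VotingClassifier outside of GridSearchCV, no error occurs, but when I fit it inside GridSearchCV, I get the error message. Another question, "VotingClassifier with pipelines as estimators", reported the same error (without using grid search) and was fixed by a line asserting that the Keras model is a classifier, which I have also included: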

keras_model._estimator_type = "classifier"

That did not solve the problem here.

Any suggestions?

Tags: python-3.x, keras, scikit-learn

Solution
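
No answer text survives on this page, so what follows is only a plausible explanation, not a confirmed diagnosis. GridSearchCV clones the estimator before each fit, and scikit-learn's clone() rebuilds an object from its constructor parameters (get_params()). An attribute set on the instance afterwards, such as keras_model._estimator_type = "classifier", is not a constructor parameter, so it would not carry over to the clone. The sketch below shows that mechanism using scikit-learn only. StandInWrapper and StandInClassifier are hypothetical stand-ins for KerasClassifier and are not part of any library.

```python
from sklearn.base import BaseEstimator, clone


class StandInWrapper(BaseEstimator):
    """Hypothetical stand-in for KerasClassifier; not a real library class."""

    def __init__(self, build_fn=None):
        self.build_fn = build_fn


class StandInClassifier(StandInWrapper):
    # Class-level attribute: present on every instance, including clones.
    _estimator_type = "classifier"


# Instance-level tag, set after construction as in the question.
tagged_instance = StandInWrapper()
tagged_instance._estimator_type = "classifier"

# clone() builds a fresh object from get_params() only, so attributes
# added to the instance after construction are not copied.
clone_of_instance = clone(tagged_instance)
clone_of_subclass = clone(StandInClassifier())

print("instance-level tag after clone:",
      getattr(clone_of_instance, "_estimator_type", None))
print("class-level tag after clone:",
      getattr(clone_of_subclass, "_estimator_type", None))
```

If this is indeed the cause, the same pattern applied to the question would be a small subclass of KerasClassifier that declares _estimator_type = "classifier" at class level, used in place of setting the attribute on the instance. Whether that resolves the error depends on the installed Keras and scikit-learn versions, so treat it as something to try rather than a guaranteed fix.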

