python-3.x - 包含 Keras 模型的网格搜索投票分类器
问题描述
我正在VotingClassifier
尝试Keras
使用GridSearchCV
.
这是代码:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import adam
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.ensemble import VotingClassifier
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
# pretend data
X, y = datasets.make_classification(n_samples=100, n_features=20)
scaler = StandardScaler()
# create model
def create_model():
model = Sequential()
model.add(Dense(20, kernel_initializer="uniform", activation='relu', input_shape=(20,)))
model.add(Dense(30, kernel_initializer="uniform", activation='relu'))
model.add(Dense(10, kernel_initializer="uniform", activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
optimizer = adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
return model
keras_model = KerasClassifier(build_fn=create_model)
keras_model._estimator_type = "classifier"
eclf = VotingClassifier(
estimators=[('svc',SVC(probability=True)), ('keras_model', keras_model)]
, voting='soft')
# Test - fit the viting classifier without grid search
eclf.fit(X, y)
print('The VotingClassifier can be fit outside of gridsearch\n')
# parameters to grid search
params = [{'svc__C':[0.01,0.1]}, ]
grid = GridSearchCV(eclf,params,cv=2,scoring='accuracy', verbose=1)
grid.fit(X,y)
我收到以下错误:
ValueError: The estimator KerasClassifier should be a classifier.
当我在VotingClassifier
外部训练时GridSearchCV
不会发生错误,但是当我在内部训练时GridSearchCV
,我会收到错误消息。另一个问题VotingClassifier with pipelines as estimators有相同的错误(不使用 GridSearch),并由断言 keras 模型是一个分类器的行修复,我也包括:
keras_model._estimator_type = "classifier"
这并没有解决这里的问题。
有什么建议么?
解决方案
推荐阅读
- php - Swift_TransportException:预期的响应代码 220,但得到一个空响应 - laravel
- c - 在 bin 排序算法中获取 malloc 断言错误
- visual-studio-code - 如何将 Visual Studio Code 中的面板移动到象限?
- apache-camel - 聚合apache骆驼中的optimisticLocking是什么意思?
- html - HTML 表单中的选择和复选框值是否支持非拉丁字符?
- azure - Azure 应用服务中的证书排除路径
- node.js - 如何强制将此图像转换为同步模式?
- python - 如何使用 Python Pandas 比较两个不同大小的数据集?
- c# - 停止 FileSystemWatcher 后运行另一段代码
- python - dockerised python解释器抱怨在docker中绑定安装源时缺少包