python - keras:返回 model.summary() vs scikit learn wrapper
问题描述
在使用 keras 时,我了解到使用包装器会对 keras 和 scikit learn api 请求产生不利影响。我对两者兼有的解决方案感兴趣。
变体 1:scikit 包装器
from keras.wrappers.scikit_learn import KerasClassifier
def model():
model = Sequential()
model.add(Dense(10, input_dim=4, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
estimator = KerasClassifier(build_fn=model, epochs=100, batch_size=5)
model.fit(X, y)
->这让我可以打印诸如 accuracy_score() 或 classification_report() 之类的 scikit 命令。但是, model.summary() 不起作用:
AttributeError:“KerasClassifier”对象没有属性“摘要”
变体 2:无包装
model = Sequential()
model.add(Dense(10, input_dim=4, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X, y, epochs=100, batch_size=5)
->这让我可以打印 model.summary() 但不能打印 scikit 命令。
ValueError:不允许混合 y 类型,得到类型 {'multiclass'、'multilabel-indicator'}
有没有一种方法可以同时使用两者?
解决方案
KerasClassifier
只是对实际Model
in的包装,keras
因此 keras api 的实际方法可以路由到 scikit 中使用的方法,因此可以与 scikit 实用程序结合使用。但在内部它只使用可以通过 using 访问的模型estimator.model
。
说明上述内容的示例:
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.datasets import make_classification
def model():
model = Sequential()
model.add(Dense(10, input_dim=20, activation='relu'))
model.add(Dense(2, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
estimator = KerasClassifier(build_fn=model, epochs=100, batch_size=5)
X, y = make_classification()
estimator.fit(X, y)
# This is what you need
estimator.model.summary()
输出是:
Layer (type) Output Shape Param #
=================================================================
dense_9 (Dense) (None, 10) 210
_________________________________________________________________
dense_10 (Dense) (None, 2) 22
=================================================================
Total params: 232
Trainable params: 232
Non-trainable params: 0
_________________________________________________________________
推荐阅读
- python-3.x - 我们如何将执行 pandas read_csv() 时生成的警告消息存储到变量中?
- mongodb - 如何验证两个非关系数据库之间的数据一致性?
- javascript - 一个简单的计算器(只加)oop javascript
- sql-server - 将前 x 列从一个表复制到另一个空表
- android - H264 格式没有音频 如何在 h264 中获取音频
- javascript - 无法在编辑页面上更改我的输入类型值
- azure-devops - 在 Azure DevOps 中的构建和发布管道之间共享变量
- angularjs - 如何使用 selenium 单击具有动态 id 的 ng show 元素
- reactjs - react.js 应用程序中 iframe 下的 Reactstrap 模态
- c++ - C++17之前的类模板参数推导