python - 如何使用 joblib 或 pickle 导出从 KerasClassifier 和 Gridsearchcv 创建的模型?
问题描述
def network(optimizers='rmsprop'):
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Embedding(vocab_size, 100, weights=[embedding_matrix], input_length=length, trainable=True))
# model.add(LSTM(50, recurrent_dropout=.20))
model.add(tf.keras.layers.Dense(8, activation='relu'))
# model.add(Dense(200, activation='relu'))
model.add(tf.keras.layers.Dropout(.40))
# model.add(Flatten())
model.add(tf.keras.layers.Dense(50, activation='relu'))
# model.add(Dropout(.3))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(8, activation='sigmoid'))
#model compile
model.compile(optimizer=optimizers, loss='binary_crossentropy', metrics=['accuracy'])
return model
num_tags = 8
#create hyperparameter search space
epochs = [25, 50]
batches = [5, 10, 100]
optimizers = ['rmsprop', 'adam']
neural_network = KerasClassifier(build_fn=network, verbose = 1)
hyperparamters = dict(optimizers=optimizers, epochs=epochs, batch_size = batches)
grid = GridSearchCV(estimator=neural_network, cv=2, param_grid=hyperparamters)
grid.fit(X_train, Ytrain)
我正在尝试使用这样的泡菜保存模型:
import pickle
pickle.dump(grid,open('keras_saved_model.pickle', 'wb'))
这是我得到的错误:
TypeError Traceback (most recent call last)
<ipython-input-24-07d3687ecaaf> in <module>()
1 import pickle
----> 2 pickle.dump(grid,open('keras_saved_model.pickle', 'wb'))
TypeError: can't pickle _thread._local objects
我也尝试了 joblib,但它显示了同样的错误。
from sklearn.externals import joblib
joblib.dump(grid.best_estimator_, 'keras_saved_model.pkl')
任何帮助将不胜感激!..
解决方案
Keras 与 pickle 不兼容。如果你愿意修补它,你可以修复它:https ://github.com/tensorflow/tensorflow/pull/39609#issuecomment-683370566 。您还可以使用 SciKeras 库,它会为您执行此操作,并且可以替代KerasClassifier
:https ://github.com/adriangb/scikeras
披露:我是 SciKeras 以及那个 PR 的作者。
推荐阅读
- javascript - 使用 bulma 手风琴,在两个手风琴中,如何通过单击前手风琴内的按钮来显示后手风琴?
- java - 使用二维数组的Java中DP的CoinChange问题
- java - spring-boot REST 响应的最佳实践
- javascript - 如何将动态输入限制为仅一位数并锁定按键,除了使用纯Javascript的旋转按钮?
- html - 如何实现这样的设计风格?
- javascript - 在结果之上使用实时添加对 Firestore 数据进行分页
- java - 在 Xamarin Forms Android 中打印到 POS 打印机
- python - pandas 中按组计算的值
- android - 在 Android Studio 中查找重复代码 (Kotlin)
- reactjs - 如何在单元测试中使用 connected-react-router 的推送