How to export a model created with KerasClassifier and GridSearchCV using joblib or pickle?

Problem description

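import tensorflow as tf
from sklearn.model_selection import GridSearchCV
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier  # legacy (deprecated) wrapper

# vocab_size, embedding_matrix, length, X_train and Ytrain are defined earlier in the notebook
def network(optimizers='rmsprop'):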

    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Embedding(vocab_size, 100, weights=[embedding_matrix], input_length=length, trainable=True))
    # model.add(LSTM(50, recurrent_dropout=.20))

    model.add(tf.keras.layers.Dense(8, activation='relu')) 
    # model.add(Dense(200, activation='relu'))
    model.add(tf.keras.layers.Dropout(.40))
    # model.add(Flatten())
    model.add(tf.keras.layers.Dense(50, activation='relu'))
    # model.add(Dropout(.3))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(8, activation='sigmoid'))

    #model compile 
    model.compile(optimizer=optimizers, loss='binary_crossentropy', metrics=['accuracy'])
    return model 

num_tags = 8
#create hyperparameter search space 
epochs = [25, 50]
batches = [5, 10, 100]
optimizers = ['rmsprop', 'adam']
neural_network = KerasClassifier(build_fn=network, verbose=1)
hyperparameters = dict(optimizers=optimizers, epochs=epochs, batch_size=batches)
grid = GridSearchCV(estimator=neural_network, cv=2, param_grid=hyperparameters)
grid.fit(X_train, Ytrain)

I am trying to save the model with pickle like this:

import pickle
pickle.dump(grid,open('keras_saved_model.pickle', 'wb'))

This is the error I get:

TypeError                                 Traceback (most recent call last)
<ipython-input-24-07d3687ecaaf> in <module>()
      1 import pickle
----> 2 pickle.dump(grid,open('keras_saved_model.pickle', 'wb'))

TypeError: can't pickle _thread._local objects
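The traceback points at the cause: somewhere inside the fitted estimator there is a `threading.local` object, and pickle refuses to serialize those. The standard-library sketch below reproduces the same `TypeError`; the `model_state` dict and its keys are hypothetical stand-ins for a model's attributes, not actual Keras internals.

```python
import pickle
import threading

# Hypothetical stand-in for a fitted model's attribute dictionary:
# plain data (picklable) plus a thread-local object (not picklable).
model_state = {
    "weights": [0.1, 0.2, 0.3],
    "_thread_local": threading.local(),
}

try:
    pickle.dumps(model_state)
    error_message = None
except TypeError as exc:
    # Python 3.7 reports "can't pickle _thread._local objects";
    # newer versions report "cannot pickle '_thread._local' object".
    error_message = str(exc)

print(error_message)
```

A single unpicklable attribute anywhere in the object graph is enough to make the whole `dump` fail, which is why pickling the entire `grid` object hits it.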

I also tried joblib, but it gives the same error:

import joblib  # sklearn.externals.joblib is deprecated; use the standalone joblib package
joblib.dump(grid.best_estimator_, 'keras_saved_model.pkl')

Any help would be greatly appreciated!

Tags: python, scikit-learn, tf.keras, gridsearchcv

Solution


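Keras models are not compatible with pickle. If you are willing to patch Keras at runtime, you can fix this yourself: https://github.com/tensorflow/tensorflow/pull/39609#issuecomment-683370566. You can also use the SciKeras library, which does this for you and works as a drop-in replacement for KerasClassifier: https://github.com/adriangb/scikeras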
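One general way such a patch can work is by giving the model class a custom `__reduce__`, which tells pickle how to turn the object into plain data and how to rebuild it on load. The standard-library sketch below shows that technique on a hypothetical `PicklableModel` class (not a Keras class, and not the code from the linked comment): `__reduce__` hands pickle only the config and weights, and the unpicklable thread-local state is simply recreated when the object is rebuilt.

```python
import pickle
import threading


def _rebuild(config, weights):
    """Recreate the object from plain, picklable data (hypothetical helper)."""
    obj = PicklableModel(config)
    obj.weights = list(weights)
    return obj


class PicklableModel:
    """Hypothetical stand-in for a model: plain data plus an unpicklable thread-local."""

    def __init__(self, config):
        self.config = dict(config)
        self.weights = []
        self._local = threading.local()  # runtime-only state that pickle cannot handle

    def __reduce__(self):
        # Instead of copying __dict__ (which would hit the thread-local and fail),
        # tell pickle to call _rebuild(config, weights) when loading.
        return (_rebuild, (self.config, self.weights))


model = PicklableModel({"units": 8, "activation": "sigmoid"})
model.weights = [0.5, -0.25]
restored = pickle.loads(pickle.dumps(model))
print(restored.config, restored.weights)
```

For a real Keras model the "plain data" would be the architecture, the training configuration and the weights rather than a small dict and list.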

Disclosure: I am the author of SciKeras as well as of that PR.

