首页 > 解决方案 > 在 Python 中保存经过训练的多输入分类算法

问题描述

我开发了一个脚本,根据之前手动标记的反馈预测某些文本的可能标记。我使用了几篇在线文章来帮助我(即:https ://towardsdatascience.com/multi-label-text-classification-with-scikit-learn-30714b7819c5 )。

因为我想要每个标签的概率,所以这是我使用的代码:

NB_pipeline = Pipeline([
    ('clf', OneVsRestClassifier(MultinomialNB(alpha=0.3, fit_prior=True, class_prior=None))),
    ])

predictions_en = {}
for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

preds_en = pd.DataFrame(predictions_en.items())
preds_en = preds_en.sort_values(by=[1], ascending=False)
preds_en = preds_en.reset_index(drop=True)

它非常适合我的目的:它为每个可能的标签返回一个预测。但我的问题是,每次我尝试进行预测时,它都会重新训练算法。我想做的是在脚本中训练算法,保存训练后的算法,将其加载到另一个进行预测的脚本中。

我希望能够在脚本 1 中执行此操作:

for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])

这在另一个脚本中:

for category in categories_en:
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

但我似乎无法让它工作。当我尝试将它分开时,它只是给了我相同的预测。

标签: pythonmachine-learningtext-classificationmulticlass-classification

解决方案


你总是可以pickle用来序列化任何 python 对象,包括你的。因此,保存模型最简单、最快的方法就是将其序列化为一个文件,例如model.pickle. 这是在训练模型后的第一部分完成的。之后,您所要做的就是检查文件是否存在并pickle再次使用反序列化它。

这是一个将 python 对象序列化为文件的函数:

import pickle

def serialize(obj, file):
    with open(file, 'wb') as f:
        pickle.dump(obj, f)

这是一个从文件中反序列化 python 对象的函数:

import pickle

def deserialize(file):
    with open(file, 'rb') as f:
        return pickle.load(f)

完成训练后,您所要做的就是调用(如果NB_pipeline是模型的对象):

serialize(NB_pipeline, 'model.pickle')

当你必须加载它并使用它时,只需调用:

NB_pipeline = deserialize('model.pickle')

推荐阅读