首页 > 解决方案 > 从 SimpleTransformer 的分类模型中提取层输出

问题描述

我已经为文本分类任务微调了一个 BERT 基础模型。现在,我想提取隐藏层输出,以便将此输出与其他特征结合起来训练随机森林模型。问题在于我不知道如何提取隐藏层输出。如果有人能在这方面帮助我,那就太好了。

from simpletransformers.classification import ClassificationModel

model_xlm = ClassificationModel('bert', 'bert-base-uncased')
model_xlm.train_model(df_train)

标签: pythontext-classificationbert-language-modelsimpletransformers

解决方案


在谷歌搜索中遇到了这个问题,最近通过阅读 SimpleTransformers git repo 中的源代码看到了这个问题的答案。

要获得隐藏层输出,只需将output_hidden_​​states传递给模型:

model_xlm = ClassificationModel('bert', 'bert-base-uncased', {"output_hidden_states": True})

那么当你调用 predict 时,你会得到这个:

preds, model_outputs, all_embedding_outputs, all_layer_hidden_states = model.predict(data)

这是源中的参考: https ://github.com/ThilinaRajapakse/simpletransformers/blob/68f0faace15530fa1a738a34ea13521eec4518b1/simpletransformers/classification/classification_model.py#L2242


推荐阅读