How to get all outputs of the last transformer encoder in a BERT pretrained model, not just the CLS token output?

Question

I am using PyTorch, and this is the model from Hugging Face transformers:

from transformers import BertTokenizerFast, BertForSequenceClassification
bert = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                     num_labels=int(data['class'].nunique()),
                                                     output_attentions=False,
                                                     output_hidden_states=False)

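In the forward function I am building, I call x1, x2 = self.bert(sent_id, attention_mask=mask). As far as I understand, x2 is the CLS output (which is the output of the first transformer encoder), but I don't think I fully understand the model's output. What I want is the output of all 12 of the last transformer encoders. How can I do this in PyTorch?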

Tags: neural-network, pytorch, text-classification, bert-language-model, huggingface-transformers

Solution


Ideally, if you want to look at the outputs of all the layers, you should use BertModel and not BertForSequenceClassification. BertForSequenceClassification wraps a BertModel and adds a linear classification layer on top of it, so what it returns by default is the classification logits rather than the encoder outputs.

from transformers import BertModel
my_bert_model = BertModel.from_pretrained("bert-base-uncased")

### Add your code to map the model to device, data to device, and obtain input_ids and mask

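# return_dict=False makes the model return a plain tuple; recent versions of transformers
# return an output object by default, and unpacking that object yields its keys, not tensors.
sequence_output, pooled_output = my_bert_model(ids, attention_mask=mask, return_dict=False)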

# sequence_output has shape (batch_size, sequence_length, 768) and contains the output for every token in the last layer of the BERT model.
# pooled_output has shape (batch_size, 768): it is the last-layer [CLS] vector passed through an extra linear layer and tanh.
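The difference between the two return values can be checked without downloading any weights. The following is a minimal sketch that builds a tiny, randomly initialised BertModel; the configuration sizes (hidden_size=32 and so on), the random input ids, and the name tiny_bert are assumptions chosen only to keep the example fast. With bert-base-uncased the hidden size would be 768.

```python
import torch
from transformers import BertConfig, BertModel

# Hypothetical tiny configuration: random weights, nothing is downloaded.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
tiny_bert = BertModel(config)
tiny_bert.eval()  # disable dropout so the outputs are deterministic

# Made-up batch: 2 sequences of 7 token ids each, no padding.
ids = torch.randint(0, config.vocab_size, (2, 7))
mask = torch.ones_like(ids)

with torch.no_grad():
    out = tiny_bert(ids, attention_mask=mask, return_dict=True)

sequence_output = out.last_hidden_state  # every token, last encoder layer
pooled_output = out.pooler_output        # one vector per sequence
cls_hidden = sequence_output[:, 0, :]    # raw last-layer vector of the first ([CLS]) position

print(sequence_output.shape)
print(pooled_output.shape)
```

Accessing the fields by name (last_hidden_state, pooler_output) is equivalent to unpacking the tuple returned with return_dict=False, and does not depend on the position of each value in that tuple.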

In order to obtain the outputs of all the transformer encoder layers, you can use the following:

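my_bert_model = BertModel.from_pretrained("bert-base-uncased")
# return_dict=False keeps the tuple return format that the unpacking below relies on
sequence_output, pooled_output, all_layer_output = my_bert_model(ids, attention_mask=mask, output_hidden_states=True, return_dict=False)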

all_layer_output is a tuple containing the output of the embedding layer followed by the output of every encoder layer, so it has 13 elements for bert-base-uncased. Each element in the tuple has shape (batch_size, sequence_length, 768).

Hence, to get the sequence of outputs at layer 5, you can use all_layer_output[5], because all_layer_output[0] contains the output of the embeddings.
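The indexing described above can be sketched on a small model that needs no download. The 4-layer configuration, the random input ids, and the names tiny_bert and stacked_layers are assumptions made for illustration; only the structure of the hidden-states tuple carries over to bert-base-uncased, which has 12 encoder layers and therefore 13 entries.

```python
import torch
from transformers import BertConfig, BertModel

# Hypothetical tiny configuration with 4 encoder layers and random weights.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=4,
                    num_attention_heads=4, intermediate_size=64)
tiny_bert = BertModel(config)
tiny_bert.eval()  # disable dropout

# Made-up batch: 2 sequences of 7 token ids each, no padding.
ids = torch.randint(0, config.vocab_size, (2, 7))
mask = torch.ones_like(ids)

with torch.no_grad():
    out = tiny_bert(ids, attention_mask=mask,
                    output_hidden_states=True, return_dict=True)

all_layer_output = out.hidden_states      # tuple: embeddings + one entry per encoder layer
embedding_output = all_layer_output[0]    # output of the embedding layer
layer_3_output = all_layer_output[3]      # output of the third encoder layer
last_layer_output = all_layer_output[-1]  # output of the final encoder layer

# Stack only the encoder layers (skipping the embeddings) into a single tensor
# of shape (num_layers, batch_size, sequence_length, hidden_size).
stacked_layers = torch.stack(all_layer_output[1:], dim=0)

print(len(all_layer_output))
print(stacked_layers.shape)
```

Stacking is only one possible way to hold the per-layer outputs; keeping the tuple and indexing it directly works just as well when only a few layers are needed.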

