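neural-network - How do I get all the outputs of the last transformer encoder in a pretrained BERT model, not just the CLS token output?
Problem description
I'm using PyTorch, and this is the model from huggingface transformers (link):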
from transformers import BertTokenizerFast, BertForSequenceClassification
bert = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                     num_labels=int(data['class'].nunique()),
                                                     output_attentions=False,
                                                     output_hidden_states=False)
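In the forward function I'm building, I call x1, x2 = self.bert(sent_id, attention_mask=mask)
Now, as far as I know, x2 is the cls output (which is the output of the first transformer encoder), but then again I don't think I understand the model's outputs. What I want is the outputs of all 12 of the last transformer encoders. How can I do this in PyTorch?
Solution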
Ideally, if you want to look at the outputs of all the layers, you should use BertModel and not BertForSequenceClassification. BertForSequenceClassification wraps a BertModel (stored as its bert attribute) and adds a dropout and a linear classification layer on top of the BERT model.
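Since the question already builds a BertForSequenceClassification, here is a minimal sketch of two ways to reach the per-token outputs without switching models. It assumes transformers v4 or later (outputs come back as ModelOutput objects) and uses a tiny, randomly initialised BertConfig with arbitrarily chosen sizes so it runs without downloading weights; with the real model you would call from_pretrained("bert-base-uncased") instead. The names clf, input_ids and mask are illustrative.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification, BertModel

# Tiny random config (assumption) so the sketch runs offline; sizes are arbitrary.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=3,
)
clf = BertForSequenceClassification(config).eval()  # eval() disables dropout

# The classifier keeps the underlying encoder as its .bert attribute.
print(isinstance(clf.bert, BertModel))

input_ids = torch.randint(0, config.vocab_size, (2, 7))  # (batch_size, sequence_length)
mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Option 1: ask the classifier itself to return the hidden states.
    out = clf(input_ids, attention_mask=mask, output_hidden_states=True)
    logits = out.logits                 # (batch_size, num_labels)
    last_layer = out.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)

    # Option 2: call the wrapped BertModel directly.
    enc_out = clf.bert(input_ids, attention_mask=mask)
    sequence_output = enc_out.last_hidden_state  # same shape as last_layer
```

Both routes yield the same last-layer tensor in eval mode; the first also gives you the classifier's logits from the same forward pass.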
from transformers import BertModel
my_bert_model = BertModel.from_pretrained("bert-base-uncased")
### Add your code to map the model to device, data to device, and obtain input_ids and mask
outputs = my_bert_model(ids, attention_mask=mask)
# In transformers v4+ the model returns a ModelOutput object: read fields by name,
# because tuple-unpacking it directly yields its key names rather than the tensors.
sequence_output = outputs.last_hidden_state  # shape: (batch_size, sequence_length, 768)
pooled_output = outputs.pooler_output        # shape: (batch_size, 768)
sequence_output contains the output for every token in the last layer of the BERT model.
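To make the difference between sequence_output and pooled_output concrete, the sketch below (again assuming transformers v4+ and a tiny random BertConfig standing in for bert-base-uncased) shows that the [CLS] position is just row 0 of the last layer, while pooler_output is that row passed through BERT's pooler (a Linear layer followed by tanh), so it is not the raw [CLS] output.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random config (assumption) instead of downloading bert-base-uncased.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()

ids = torch.randint(0, config.vocab_size, (2, 7))  # (batch_size, sequence_length)
mask = torch.ones_like(ids)

with torch.no_grad():
    outputs = model(ids, attention_mask=mask)
    sequence_output = outputs.last_hidden_state  # (2, 7, 32): every token, last layer
    pooled_output = outputs.pooler_output        # (2, 32): one vector per sequence

    # Last-layer output at position 0, where the tokenizer places [CLS].
    cls_vector = sequence_output[:, 0, :]
    # pooler_output = tanh(Linear(cls_vector)), computed by BERT's pooler module.
    recomputed = model.pooler.activation(model.pooler.dense(cls_vector))
```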
To obtain the outputs of all the transformer encoder layers, pass output_hidden_states=True, either in the forward call as below or to from_pretrained:
my_bert_model = BertModel.from_pretrained("bert-base-uncased")
outputs = my_bert_model(ids, attention_mask=mask, output_hidden_states=True)
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
all_layer_output = outputs.hidden_states
all_layer_output is a tuple containing the output of the embeddings layer followed by the output of each encoder layer, i.e. 13 tensors for the 12-layer bert-base-uncased. Each element has shape (batch_size, sequence_length, 768).
Hence, to get the output of encoder layer 5, use all_layer_output[5], since all_layer_output[0] holds the output of the embeddings. The last element, all_layer_output[-1], equals sequence_output.
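A minimal sketch of working with the hidden-states tuple, assuming transformers v4+ and a tiny random BertConfig that keeps bert-base's 12 layers but uses an arbitrary small hidden size so it runs offline. It checks the 12 + 1 layout and stacks every layer into a single tensor, which is convenient if you want to pool or index across layers.

```python
import torch
from transformers import BertConfig, BertModel

# 12 encoder layers like bert-base, but a tiny hidden size (assumption) for speed.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=12,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()

ids = torch.randint(0, config.vocab_size, (2, 7))  # (batch_size, sequence_length)
mask = torch.ones_like(ids)

with torch.no_grad():
    outputs = model(ids, attention_mask=mask, output_hidden_states=True)

hidden_states = outputs.hidden_states  # tuple: embeddings output + one entry per encoder layer
embeddings_out = hidden_states[0]      # output of the embedding layer
layer_5_out = hidden_states[5]         # output of encoder layer 5
last_layer = hidden_states[-1]         # same values as outputs.last_hidden_state

# Stack into one tensor: (num_layers + 1, batch_size, sequence_length, hidden_size)
all_layers = torch.stack(hidden_states, dim=0)
print(len(hidden_states), tuple(all_layers.shape))
```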