pytorch - 没有分类层的拥抱脸变压器伯特模型
问题描述
我想从 vgg16 进行联合嵌入并bert
进行分类。
问题huggingface transformers bert
是它具有具有num_labels
维度的分类层。
但是,我想要BertPooler
(768 维)的输出,我将其用作扩展模型的文本嵌入。
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
这给出了以下模型:
BertForSequenceClassification(
...
...
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=2, bias=True)
)
我怎样才能摆脱classifier
层?
解决方案
from transformers import BertModel
model = BertModel.from_pretrained('bert-base-uncased')
输出
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
在此处查看 BertModel 定义。
推荐阅读
- java - 如何在java中提取前两个字符和后两个字符之间的字符串?
- time - 如何验证时间是否在SSRS中的时间范围内
- javascript - 无法合并两个对象数组
- kubernetes - 工作节点上的 kubernetes 内存承诺限制是什么?
- javascript - 理解 JavaScript 中的逻辑非运算符
- opencv - 删除“从区域图像中的文本和地图标记”?
- java - 给定一个无限序列,将其分解为多个区间,并返回一个新的无限序列,其中包含每个区间的平均值
- sql - 如何使总和匹配位于不同行中的两个 ID - Redshift
- javascript - 返回带有回调函数参数的函数
- python - Python:当我覆盖它们时,基类属性和方法会发生什么?他们还“存在”吗?