pytorch - 我可以使用 BERT 作为特征提取器而不对我的特定数据集进行任何微调吗?
问题描述
我正在尝试解决 10 个类的多标签分类任务,其中相对平衡的训练集由约 25K 样本组成,评估集由约 5K 样本组成。
我正在使用拥抱脸:
model = transformers.BertForSequenceClassification.from_pretrained(...
并获得相当不错的结果(ROC AUC = 0.98)。
但是,我目睹了一些我似乎无法理解的奇怪行为-
我添加以下代码行:
for param in model.bert.parameters():
param.requires_grad = False
同时确保模型的其他层被学习,即:
[param[0] for param in model.named_parameters() if param[1].requires_grad == True]
gives
['classifier.weight', 'classifier.bias']
在这样配置时训练模型会产生一些令人尴尬的糟糕结果(ROC AUC = 0.59)。
我的工作假设是开箱即用的预训练 BERT 模型(没有任何微调)应该作为分类层的相对良好的特征提取器。那么,我哪里弄错了?
解决方案
根据我的经验,您的假设出错了
一个开箱即用的预训练 BERT 模型(没有任何微调)应该作为分类层的相对较好的特征提取器。
在尝试使用 BERT 的输出层作为词嵌入值时,我注意到了类似的经历,几乎没有微调,结果也很差;这也是有道理的,因为您有效地768*num_classes
以最简单的输出层形式建立了连接。与 BERT 的数百万个参数相比,这使您对高度模型复杂性的控制几乎可以忽略不计。然而,我也想在训练你的完整模型时谨慎地指出过度拟合的结果,尽管我相信你知道这一点。
BERT 的整个想法是微调模型非常便宜,因此为了获得理想的结果,我建议不要冻结任何层。禁用至少部分层可能会有所帮助的一个实例是嵌入组件,具体取决于模型的词汇量(BERT-base 约为 30k)。
推荐阅读
- reactjs - 如何使用 React.js 和 axios 将 POST 响应发送到另一个函数?
- javascript - 单击按钮后React中的重复组件
- delphi - DBGrid 上的行组的替代颜色
- swift - 如何允许在 UITest 中进行跟踪?
- firebase - 我怎样才能从 Flutter 中的 firestore 只获取我的消息
- node.js - 错误:Route.get() 需要一个回调函数,但得到一个 [object String]
- swift - 在按钮上单击从 UITableView 中删除元素 - 基本问题
- python - 如何更改熊猫的列类型
- javascript - 地图和地图异步之间的区别
- docker - Apache Proxy 背后的 Gitlab-runner 交互式 Web 终端