tensorflow - 使用 tfhub 模块时冻结 BERT 层
问题描述
在此链接单击此处,作者说:
import tensorflow_hub as hub
module = hub.Module(<<Module URL as string>>, trainable=True)
如果用户希望微调/修改模型的权重,则必须将此参数设置为 True。所以我怀疑如果我将其设置为 false 是否意味着我正在冻结 BERT 的所有层,这也是我的意图。我想知道我的方法是否正确。
解决方案
我有一个多部分的答案给你。
如何冻结模块
这一切都取决于您的优化器是如何设置的。TF1 的常用方法是使用 TRAINABLE_VARIABLES集合中的所有变量对其进行初始化。hub.Module的文档说trainable
:“如果为 False,则不会将任何变量添加到 TRAINABLE_VARIABLES 集合中,...”。所以,是的,设置trainable=False
(显式或默认)将模块冻结在 TF1 的标准用法中。
为什么不冻结 BERT
也就是说,BERT 是用来微调的。该论文以更一般的术语讨论了基于特征(即冻结)与微调方法,但模块文档清楚地说明了这一点:“微调所有参数是推荐的做法。” 这使计算池输出的最后部分更好地适应手头任务最重要的特征。
如果您打算遵循此建议,请注意tensorflow.org/hub/tf1_hub_module#fine-tuning并选择正确的图形版本:BERT 在训练期间使用dropout正则化,您需要进行设置hub.Module(..., tags={"train"})
才能获得它。但是对于推理(在评估和预测中),dropout 什么都不做,你可以省略tags=
参数(或将其设置为空set()
或 to None
)。
展望:TF2
你问了关于hub.Module()
TF1 的 API,所以我在那个上下文中回答了。同样的注意事项也适用于 TF2 SavedModel 格式的BERT 。在那里,一切都是关于设置hub.KerasLayer(..., trainable=True)
与否,但选择图形版本的需要已经消失(该层获取 Keras 的training
状态并在后台应用它)。
快乐的训练!
推荐阅读
- kotlin - 如何正确地对简单字符串值的响应式响应进行单元测试?
- python - Python正则表达式域名
- vue.js - axios Post 根据请求给出 net::ERR_UNEXPECTED_PROXY_AUTH 错误
- sql - 在日期死亡后查找第一笔交易
- javascript - 无法在 jQuery 的 ajax 成功中分配或访问当前选择框值
- css - 相同的 CSS 在具有 Chrome 的两台 PC 上具有不同的行为
- python - 在 sqlite3 中减少读取数据库冗余的功能
- javascript - 根据包括键名在内的值进行排序
- apache-spark - 有条件的 PySpark 窗口
- arrays - 在 Swift 中通过多个关键字过滤对象