nlp - BERT:作为掩码语言模型一部分的输入嵌入权重
问题描述
我查看了 BERT 掩码语言模型的不同实现。对于预训练,有两个常见的版本:
- 解码器将简单地采用 [MASK]ed 标记的最终嵌入并将其通过线性层(无需任何修改):
class LMPrediction(nn.Module):
def __init__(self, hidden_size, vocab_size):
super().__init__()
self.decoder = nn.Linear(hidden_size, vocab_size, bias = False)
self.bias = nn.Parameter(torch.zeros(vocab_size))
self.decoder.bias = self.bias
def forward(self, x):
return self.decoder(x)
- 一些实现将使用输入嵌入的权重作为解码器线性层的权重:
class LMPrediction(nn.Module):
def __init__(self, hidden_size, vocab_size, embeddings):
super().__init__()
self.decoder = nn.Linear(hidden_size, vocab_size, bias = False)
self.bias = nn.Parameter(torch.zeros(vocab_size))
self.decoder.weight = embeddings.weight ## <- THIS LINE
self.decoder.bias = self.bias
def forward(self, x):
return self.decoder(x)
哪一个是正确的?大多数情况下,我看到了第一个实现。但是,第二个也很有意义-但我在任何论文中都找不到它(我想看看第二个版本是否在某种程度上优于第一个)
解决方案
对于那些感兴趣的人,它被称为权重绑定或联合输入输出嵌入。有两篇论文论证了这种方法的好处:
推荐阅读
- android - Android Firebase UI 更改默认错误消息
- javascript - 如何使用 react-native 将 this.props 传递给 js 文件
- c# - 找到两个拟合的交点
- shell - 根据模式在文件中插入一行
- javascript - 如何修剪
- 和
- 标签,但值
- 会保持不变吗?
- r - 选择 CSV 文件并成对读取
- java - Map.Entry::getKey 在 groupby 下抛出错误
- angular - 在角度谷歌地图中计算两个纬度和经度之间的旅行时间
- database - 在 PostgreSQL 中获取最近创建的模式
- puppet - 木偶Enterprose的内置功能