tensorflow - 如何理解transformer中的masked multi-head attention
问题描述
我目前正在研究transformer的代码,但我无法理解decoder的masked multi-head。论文说是为了不让你看到生成词,但是如果生成词之后的词还没有生成,我就无法理解,怎么能看到呢?
我尝试阅读变压器的代码(链接:https ://github.com/Kyubyong/transformer )。代码实现的掩码如下所示。它使用下三角矩阵来掩盖,我不明白为什么。
padding_num = -2 ** 32 + 1
diag_vals = tf.ones_like(inputs[0, :, :]) # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1]) # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
解决方案
阅读Transformer 论文后,我也有同样的问题。我在互联网上没有找到这个问题的完整和详细的答案,所以我将尝试解释我对 Masked Multi-Head Attention 的理解。
简短的回答是——我们需要掩蔽来使训练并行。并且并行化很好,因为它可以让模型训练得更快。
这是一个解释这个想法的例子。假设我们训练将“我爱你”翻译成德语。编码器以并行模式工作——它可以在恒定步数内生成输入序列(“我爱你”)的矢量表示(即步数不取决于输入序列的长度)。
假设编码器产生数字11, 12, 13
作为输入序列的向量表示。实际上,这些向量会更长,但为简单起见,我们使用较短的向量。同样为简单起见,我们忽略了服务令牌,例如 - 序列的开头, - 序列的结尾等。
在训练过程中,我们知道翻译应该是“Ich liebe dich”(我们总是知道训练过程中的预期输出)。假设“Ich liebe dich”词的预期向量表示是21, 22, 23
。
如果我们以顺序模式训练解码器,它看起来就像是循环神经网络的训练。将执行以下顺序步骤:
- 顺序操作#1。输入:
11, 12, 13
。- 试图预测
21
。 - 预测的输出不会是准确
21
的,假设它会是21.1
。
- 试图预测
- 顺序操作#2。输入:
11, 12, 13
,也21.1
作为之前的输出。- 试图预测
22
。 - 预测的输出不会是准确
22
的,假设它会是22.3
。
- 试图预测
- 顺序操作#3。输入
11, 12, 13
,也22.3
作为之前的输出。- 试图预测
23
。 - 预测的输出不会是准确
23
的,假设它会是23.5
。
- 试图预测
这意味着我们需要进行 3 个顺序操作(一般情况下 - 每个输入一个顺序操作)。此外,我们将在每次下一次迭代中累积错误。此外,我们不使用注意力,因为我们只查看单个先前的输出。
正如我们实际上知道预期的输出一样,我们可以调整过程并使其并行。无需等待上一步输出。
- 并行操作#A。输入:
11, 12, 13
。- 试图预测
21
。
- 试图预测
- 并行操作#B。输入:
11, 12, 13
,还有21
。- 试图预测
22
。
- 试图预测
- 并行操作#C。输入:
11, 12, 13
,还有21, 22
。- 试图预测
23
。
- 试图预测
该算法可以并行执行,也不会累积错误。该算法使用注意力(即查看所有先前的输入),因此在进行预测时有更多关于上下文的信息要考虑。
这是我们需要掩蔽的地方。训练算法知道整个预期输出 ( 21, 22, 23
)。它为每个并行操作隐藏(屏蔽)这个已知输出序列的一部分。
- 当它执行 #A - 它隐藏(屏蔽)整个输出。
- 当它执行 #B - 它隐藏第二和第三输出。
- 当它执行 #C - 它隐藏第三个输出。
掩蔽本身实现如下(来自原始论文):
我们通过屏蔽掉(设置为 -∞)softmax 输入中与非法连接对应的所有值来在缩放点积注意力内部实现这一点
注意:在推理(非训练)期间,解码器以顺序(非并行)模式工作,因为它最初不知道输出序列。但它与 RNN 方法不同,因为 Transformer 推理仍然使用自我注意并查看所有先前的输出(但不仅仅是前一个输出)。
注意 2:我在一些材料中看到,掩蔽可以不同地用于非翻译应用程序。例如,对于语言建模,掩蔽可用于从输入句子中隐藏一些单词,并且模型将尝试在训练期间使用其他非掩蔽单词来预测它们(即学习理解上下文)。
推荐阅读
- android - 当我检测到某个事件时如何让我的应用弹出?
- sql - 语言相关的列标题
- javascript - 如何在reactJs中单击特定单元格的2D网格中更改框的颜色?
- javascript - 如何设置从 app.post 发送的响应的样式?
- html - 如何更改 flexbox 中图片的大小?
- c# - Microsoft-Windows-Security-SPP 记录级别等于 0 的信息
- python - 代码片段在第一次运行时提供错误消息,但第二次将完美运行而无需进行任何更改
- google-calendar-api - 有没有办法通过 nodemailer 可靠地将日历事件添加到用户日历?
- amazon-web-services - 如何使我的 meta-aws 与我的 yocto 兼容
- python - Combine all tiff images into one single image