首页 > 解决方案 > 如何理解transformer中的masked multi-head attention

问题描述

我目前正在研究transformer的代码,但我无法理解decoder的masked multi-head。论文说是为了不让你看到生成词,但是如果生成词之后的词还没有生成,我就无法理解,怎么能看到呢?

我尝试阅读变压器的代码(链接:https ://github.com/Kyubyong/transformer )。代码实现的掩码如下所示。它使用下三角矩阵来掩盖,我不明白为什么。

padding_num = -2 ** 32 + 1
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)

标签: tensorflowdeep-learningtransformerattention-model

解决方案


阅读Transformer 论文后,我也有同样的问题。我在互联网上没有找到这个问题的完整和详细的答案,所以我将尝试解释我对 Masked Multi-Head Attention 的理解。

简短的回答是——我们需要掩蔽来使训练并行。并且并行化很好,因为它可以让模型训练得更快。

这是一个解释这个想法的例子。假设我们训练将“我爱你”翻译成德语。编码器以并行模式工作——它可以在恒定步数内生成输入序列(“我爱你”)的矢量表示(即步数不取决于输入序列的长度)。

假设编码器产生数字11, 12, 13作为输入序列的向量表示。实际上,这些向量会更长,但为简单起见,我们使用较短的向量。同样为简单起见,我们忽略了服务令牌,例如 - 序列的开头, - 序列的结尾等。

在训练过程中,我们知道翻译应该是“Ich liebe dich”(我们总是知道训练过程中的预期输出)。假设“Ich liebe dich”词的预期向量表示是21, 22, 23

如果我们以顺序模式训练解码器,它看起来就像是循环神经网络的训练。将执行以下顺序步骤:

  • 顺序操作#1。输入:11, 12, 13
    • 试图预测21
    • 预测的输出不会是准确21的,假设它会是21.1
  • 顺序操作#2。输入:11, 12, 13,也21.1作为之前的输出。
    • 试图预测22
    • 预测的输出不会是准确22的,假设它会是22.3
  • 顺序操作#3。输入11, 12, 13,也22.3作为之前的输出。
    • 试图预测23
    • 预测的输出不会是准确23的,假设它会是23.5

这意味着我们需要进行 3 个顺序操作(一般情况下 - 每个输入一个顺序操作)。此外,我们将在每次下一次迭代中累积错误。此外,我们不使用注意力,因为我们只查看单个先前的输出。

正如我们实际上知道预期的输出一样,我们可以调整过程并使其并行。无需等待上一步输出。

  • 并行操作#A。输入:11, 12, 13
    • 试图预测21
  • 并行操作#B。输入:11, 12, 13,还有21
    • 试图预测22
  • 并行操作#C。输入:11, 12, 13,还有21, 22
    • 试图预测23

该算法可以并行执行,也不会累积错误。该算法使用注意力(即查看所有先前的输入),因此在进行预测时有更多关于上下文的信息要考虑。

这是我们需要掩蔽的地方。训练算法知道整个预期输出 ( 21, 22, 23)。它为每个并行操作隐藏(屏蔽)这个已知输出序列的一部分。

  • 当它执行 #A - 它隐藏(屏蔽)整个输出。
  • 当它执行 #B - 它隐藏第二和第三输出。
  • 当它执行 #C - 它隐藏第三个输出。

掩蔽本身实现如下(来自原始论文):

我们通过屏蔽掉(设置为 -∞)softmax 输入中与非法连接对应的所有值来在缩放点积注意力内部实现这一点

注意:在推理(非训练)期间,解码器以顺序(非并行)模式工作,因为它最初不知道输出序列。但它与 RNN 方法不同,因为 Transformer 推理仍然使用自我注意并查看所有先前的输出(但不仅仅是前一个输出)。

注意 2:我在一些材料中看到,掩蔽可以不同地用于非翻译应用程序。例如,对于语言建模,掩蔽可用于从输入句子中隐藏一些单词,并且模型将尝试在训练期间使用其他非掩蔽单词来预测它们(即学习理解上下文)。


推荐阅读