首页 > 解决方案 > Transformer 模型中自注意力的计算复杂性

问题描述

我最近浏览了 Google Research 的Transformer论文,描述了自注意力层如何完全取代传统的基于 RNN 的序列编码层进行机器翻译。n在论文的表 1 中,作者比较了不同序列编码层的计算复杂度,并指出(稍后)当序列长度小于向量表示的维度时,自注意力层比 RNN 层更快d

然而,如果我对计算的理解是正确的,self-attention 层的复杂性似乎比声称的要低。让X成为自注意力层的输入。然后,X将具有形状,(n, d)因为每个维度都有n词向量(对应于行)d。计算 self-attention 的输出需要以下步骤(为简单起见,考虑单头 self-attention):

  1. 线性变换 的行X以计算 query Q、 keyK和 valueV矩阵,每个矩阵都有 shape (n, d)。这是通过X与 3 个学习的形状矩阵进行后乘来实现的(d, d),计算复杂度为O(n d^2).
  2. 计算层输出,在论文的公式 1 中指定为SoftMax(Q Kt / sqrt(d)) V,其中对每一行计算 softmax。计算Q Kt具有复杂性O(n^2 d),将结果与后乘V也具有复杂性O(n^2 d)

因此,该层的总复杂度为O(n^2 d + n d^2),比传统的 RNN 层要差。dk在考虑适当的中间表示维度 ( , dv) 并最终乘以头数时,我也获得了多头注意力的相同结果h

为什么作者在报告总计算复杂度时忽略了计算查询、键和值矩阵的成本?

我知道所提议的层在各个n职位之间是完全可并行的,但我相信表 1 无论如何都没有考虑到这一点。

标签: machine-learningdeep-learningneural-networknlpartificial-intelligence

解决方案


首先,您的复杂性计算是正确的。那么,混乱的根源是什么?

最初的Attention 论文刚被引入时,它不需要计算Q和矩阵VK因为值直接取自 RNN 的隐藏状态,因此 Attention 层的复杂度 O(n^2·d)

现在,要了解Table 1包含的内容,请记住大多数人是如何扫描论文的:他们阅读标题、摘要,然后查看图表。只有在结果有趣的情况下,他们才会更彻底地阅读论文。因此,本文的主要思想是Attention is all you need在 seq2seq 设置中用注意力机制完全替换 RNN 层,因为 RNN 的训练速度非常慢。如果你Table 1在这个上下文中查看,你会发现它比较了 RNN、CNN 和 Attention,并突出了论文的动机:使用 Attention 应该比 RNN 和 CNN 更有利。它应该在 3 个方面具有优势:恒定数量的计算步骤、恒定数量的操作以及对于通常的 Google 设置的较低计算复杂度,其中n ~= 100d ~= 1000. 但正如任何想法一样,它撞上了现实的硬墙。实际上,为了让这个好主意发挥作用,他们必须添加位置编码,重新制定注意力并为其添加多个头。结果是 Transformer 架构,虽然其计算复杂度O(n^2·d + n·d^2)仍然比 RNN 快得多(在挂钟时间的意义上),并产生更好的结果。

所以你的问题的答案是作者所指的注意力层Table 1严格来说是注意力机制。这不是 Transformer 的复杂性。他们非常清楚他们模型的复杂性(我引用):

然而,可分离卷积 [6] 将复杂度大大降低到O(k·n·d + n·d^2). 然而,即使使用k = n,可分离卷积的复杂度也等于自注意力层和逐点前馈层的组合,这是我们在模型中采用的方法。


推荐阅读