machine-learning - Transformer 模型中自注意力的计算复杂性
问题描述
我最近浏览了 Google Research 的Transformer论文,描述了自注意力层如何完全取代传统的基于 RNN 的序列编码层进行机器翻译。n
在论文的表 1 中,作者比较了不同序列编码层的计算复杂度,并指出(稍后)当序列长度小于向量表示的维度时,自注意力层比 RNN 层更快d
。
然而,如果我对计算的理解是正确的,self-attention 层的复杂性似乎比声称的要低。让X
成为自注意力层的输入。然后,X
将具有形状,(n, d)
因为每个维度都有n
词向量(对应于行)d
。计算 self-attention 的输出需要以下步骤(为简单起见,考虑单头 self-attention):
- 线性变换 的行
X
以计算 queryQ
、 keyK
和 valueV
矩阵,每个矩阵都有 shape(n, d)
。这是通过X
与 3 个学习的形状矩阵进行后乘来实现的(d, d)
,计算复杂度为O(n d^2)
. - 计算层输出,在论文的公式 1 中指定为
SoftMax(Q Kt / sqrt(d)) V
,其中对每一行计算 softmax。计算Q Kt
具有复杂性O(n^2 d)
,将结果与后乘V
也具有复杂性O(n^2 d)
。
因此,该层的总复杂度为O(n^2 d + n d^2)
,比传统的 RNN 层要差。dk
在考虑适当的中间表示维度 ( , dv
) 并最终乘以头数时,我也获得了多头注意力的相同结果h
。
为什么作者在报告总计算复杂度时忽略了计算查询、键和值矩阵的成本?
我知道所提议的层在各个n
职位之间是完全可并行的,但我相信表 1 无论如何都没有考虑到这一点。
解决方案
首先,您的复杂性计算是正确的。那么,混乱的根源是什么?
最初的Attention 论文刚被引入时,它不需要计算Q
和矩阵V
,K
因为值直接取自 RNN 的隐藏状态,因此 Attention 层的复杂度为 O(n^2·d)
。
现在,要了解Table 1
包含的内容,请记住大多数人是如何扫描论文的:他们阅读标题、摘要,然后查看图表。只有在结果有趣的情况下,他们才会更彻底地阅读论文。因此,本文的主要思想是Attention is all you need
在 seq2seq 设置中用注意力机制完全替换 RNN 层,因为 RNN 的训练速度非常慢。如果你Table 1
在这个上下文中查看,你会发现它比较了 RNN、CNN 和 Attention,并突出了论文的动机:使用 Attention 应该比 RNN 和 CNN 更有利。它应该在 3 个方面具有优势:恒定数量的计算步骤、恒定数量的操作以及对于通常的 Google 设置的较低计算复杂度,其中n ~= 100
和d ~= 1000
. 但正如任何想法一样,它撞上了现实的硬墙。实际上,为了让这个好主意发挥作用,他们必须添加位置编码,重新制定注意力并为其添加多个头。结果是 Transformer 架构,虽然其计算复杂度O(n^2·d + n·d^2)
仍然比 RNN 快得多(在挂钟时间的意义上),并产生更好的结果。
所以你的问题的答案是作者所指的注意力层Table 1
严格来说是注意力机制。这不是 Transformer 的复杂性。他们非常清楚他们模型的复杂性(我引用):
然而,可分离卷积 [6] 将复杂度大大降低到
O(k·n·d + n·d^2)
. 然而,即使使用k = n
,可分离卷积的复杂度也等于自注意力层和逐点前馈层的组合,这是我们在模型中采用的方法。
推荐阅读
- c# - 如何将此查询转换为 lamda 表达式
- linux - 每 5 分钟检查一次循环中的特定进程,持续 30 分钟
- flutter - 颤振字符串替换
- android - 重新定位 Google 品牌徽标使其在 Android 上更小
- c# - 我想将 10 条记录作为列表一起传递,作为 SSIS 中脚本组件的输入
- algorithm - 找到没有相邻元素相同的系列的所有排列?
- python - 升级到 tensorflow 2.0 后,我收到导入 tensorflow 的错误
- javascript - 如何将变量解析为 Modal 并从该变量解析为 Ajax?
- html - 在 td 中为 html 电子邮件左右对齐元素
- css - map.get($foo, bar) 被标记为“semi-colon expectedscss(css-semicolonexpected)”