
Problem Description

The AttentionQKV layer as implemented by Trax is shown below:

# Required imports; in Trax itself this function and PureAttention are both
# defined in trax/layers/attention.py.
from trax.layers import combinators as cb
from trax.layers import core
from trax.layers.attention import PureAttention

def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).
  See `Attention` above for further context/details.
  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
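For context, here is a minimal usage sketch. It assumes a recent Trax release (around 1.3) in which the layer is exported as tl.AttentionQKV; the toy shapes and the all-ones mask are illustrative only:

import numpy as np
import trax.layers as tl
from trax import shapes

# Toy inputs: batch 1, sequence length 3, feature depth 8 (divisible by n_heads).
q = np.random.rand(1, 3, 8).astype(np.float32)
k = np.random.rand(1, 3, 8).astype(np.float32)
v = np.random.rand(1, 3, 8).astype(np.float32)
mask = np.ones((1, 1, 1, 3), dtype=bool)  # broadcastable over (batch, heads, len_q, len_k)

layer = tl.AttentionQKV(d_feature=8, n_heads=2, mode='eval')
layer.init(shapes.signature((q, k, v, mask)))
activations, out_mask = layer((q, k, v, mask))
print(activations.shape)  # (1, 3, 8)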

In particular, what is the purpose of the three parallel Dense layers? The inputs to this layer are q, k, v, and mask. Why are q, k, and v each passed through a Dense layer?

Tags: attention-model, trax

Solution


This snippet implements the equations at the top of page 5 of the 2017 paper Attention Is All You Need, which introduced the Transformer model. The computation is illustrated in Figure 2 of the paper:

[Figure 2 of the paper: Scaled Dot-Product Attention (left) and Multi-Head Attention (right)]
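The equations being referenced are the multi-head attention definitions from Section 3.2 of the paper:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^{O},
\qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V})

Here W_i^{Q}, W_i^{K}, W_i^{V} and W^{O} are learned projection matrices. In the Trax code, the three parallel core.Dense(d_feature) layers implement the q, k and v projections for all h heads at once (PureAttention then splits the d_feature axis into n_heads pieces), and the final core.Dense(d_feature) is the W^{O} projection applied to the concatenated heads.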

The hidden states are projected into h attention heads, which perform scaled dot-product attention in parallel. Each projection can be interpreted as extracting the information relevant to its head, so that each head then performs a probabilistic retrieval (a softmax-weighted combination of the values) according to a different, learned criterion.
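To make this concrete, below is a minimal NumPy sketch, not Trax's actual code: biases and masking are omitted, and the weight names w_q, w_k, w_v, w_o are hypothetical. It shows what the three parallel Dense layers plus PureAttention compute together.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_qkv(q, k, v, w_q, w_k, w_v, w_o, n_heads):
    """q, k, v: (batch, seq, d_feature); w_*: (d_feature, d_feature)."""
    batch, seq, d_feature = q.shape
    d_head = d_feature // n_heads

    def split_heads(x):
        # (batch, seq, d_feature) -> (batch, n_heads, seq, d_head)
        return x.reshape(batch, seq, n_heads, d_head).transpose(0, 2, 1, 3)

    # The three parallel Dense layers: learned projections of q, k and v.
    q_p = split_heads(q @ w_q)
    k_p = split_heads(k @ w_k)
    v_p = split_heads(v @ w_v)

    # Scaled dot-product attention, computed independently for each head.
    scores = q_p @ k_p.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    weights = softmax(scores)          # (batch, n_heads, seq, seq)
    heads = weights @ v_p              # (batch, n_heads, seq, d_head)

    # Concatenate the heads and apply the final Dense layer (the output projection).
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq, d_feature)
    return concat @ w_o

Using one Dense(d_feature) per projection, rather than h separate Dense(d_head) layers, is mathematically equivalent and lets a single matrix multiply cover all heads at once.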

