首页 > 解决方案 > Pytorch Transformer 模块中 MultiheadAttention 中“heads”的定义

问题描述

我对 Multihead 的定义有点困惑。
下面的 [1] 和 [2] 是否相同?

[1] 我对 multiplhead 的理解是下面的多重注意力模式。
“多组查询/键/值权重矩阵(Transformer 使用八个注意力头,因此我们最终为每个编码器/解码器提供了八组)。”
http://jalammar.github.io/illustrated-transformer/

[2] 在类 MultiheadAttention(Module) 中:在 Pytorch Transformer 模块中,似乎 embed_dim 除以头数.. 为什么?

或者... embed_dim 首先是特征维度乘以头数?

self.head_dim = embed_dim // num_heads 断言 self.head_dim * num_heads == self.embed_dim, "embed_dim 必须能被 num_heads 整除"

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py

标签: pytorchtransformer

解决方案


根据您的理解,多头注意力是对某些数据的多次注意力。

但相比之下,它不是通过将权重集乘以所需注意力的数量来实现的。相反,您重新排列与注意力数量相对应的权重矩阵,即重塑为权重矩阵。所以,从本质上讲,它仍然是多次注意力,但你正在关注权重的不同部分。


推荐阅读