首页 > 解决方案 > MultiHeadAttnetion 中的 att_mask 和 key_padding_mask 有什么区别

问题描述

pytorchatt_maskkey_padding_maskin有什么区别:MultiHeadAttnetion

key_padding_mask – 如果提供,key 中指定的填充元素将被注意力忽略。当给定一个二进制掩码并且值为 True 时,注意力层上的相应值将被忽略。当给定一个字节掩码并且一个值不为零时,注意力层上的相应值将被忽略

attn_mask – 2D 或 3D 掩码,可防止对某些位置的注意。将为所有批次广播 2D 掩码,而 3D 掩码允许为每个批次的条目指定不同的掩码。

提前致谢。

标签: pythondeep-learningpytorchtransformerattention-model

解决方案


用于屏蔽填充的key_padding_mask位置,即在输入序列结束之后。这始终特定于输入批次,并且取决于批次中的序列与最长的序列相比有多长。它是形状批量大小×输入长度的二维张量。

另一方面,attn_mask说明哪些键值对是有效的。在 Transformer 解码器中,三角形掩码用于模拟推理时间并防止关注“未来”位置。这是att_mask通常使用的。如果是二维张量,则形状为输入长度×输入长度。您还可以有一个特定于批次中每个项目的掩码。在这种情况下,您可以使用形状(batch size × num Heads) × input length × input length的 3D 张量。(因此,理论上,您可以key_padding_mask使用 3D进行模拟att_mask。)


推荐阅读