python - MultiHeadAttnetion 中的 att_mask 和 key_padding_mask 有什么区别
问题描述
pytorchatt_mask
和key_padding_mask
in有什么区别:MultiHeadAttnetion
key_padding_mask – 如果提供,key 中指定的填充元素将被注意力忽略。当给定一个二进制掩码并且值为 True 时,注意力层上的相应值将被忽略。当给定一个字节掩码并且一个值不为零时,注意力层上的相应值将被忽略
attn_mask – 2D 或 3D 掩码,可防止对某些位置的注意。将为所有批次广播 2D 掩码,而 3D 掩码允许为每个批次的条目指定不同的掩码。
提前致谢。
解决方案
用于屏蔽填充的key_padding_mask
位置,即在输入序列结束之后。这始终特定于输入批次,并且取决于批次中的序列与最长的序列相比有多长。它是形状批量大小×输入长度的二维张量。
另一方面,attn_mask
说明哪些键值对是有效的。在 Transformer 解码器中,三角形掩码用于模拟推理时间并防止关注“未来”位置。这是att_mask
通常使用的。如果是二维张量,则形状为输入长度×输入长度。您还可以有一个特定于批次中每个项目的掩码。在这种情况下,您可以使用形状(batch size × num Heads) × input length × input length的 3D 张量。(因此,理论上,您可以key_padding_mask
使用 3D进行模拟att_mask
。)
推荐阅读
- python - Pandas 读取 sharepoint excel 文件:XLRDError
- python - 具有多个蓝图路由文件的 Flask 应用程序
- javascript - e 是递归处理函数的未定义的正确语法
- reactjs - 无法使用 setState 函数更新 Zusand 状态
- reactjs - 如何使用 react-datetime 选择器仅将日期传递给数据库(不传递时间)
- heroku - 为什么我的 Heroku 托管 discord bot 没有上线?
- node.js - 找不到命名空间“FirebaseFirestore”-Node.js、Express、Typescript
- javascript - React js 中的路由保护无法正常工作
- flutter-layout - 颤动中的“RenderBox未布置”文本字段错误
- python - 如何在由 1D NumPy 数组(Python)中的值表示的索引处获取值为 1 的 2D NumPy 数组