pytorch - src_mask 和 src_key_padding_mask 的区别
问题描述
我很难理解变形金刚。一切都变得越来越清晰,但让我头疼的一件事是 src_mask 和 src_key_padding_mask 之间的区别,它在编码器层和解码器层的前向函数中作为参数传递。
https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#Transformer
解决方案
src_mask 和 src_key_padding_mask 的区别
一般的事情是注意使用张量_mask
与_key_padding_mask
. 在注意完成后,在转换器内部,我们通常会得到一个平方中间张量,其中包含所有大小比较[Tx, Tx]
(对于编码器的输入)、[Ty, Ty]
(对于移位的输出 - 解码器的输入之一)和[Ty, Tx]
(对于内存掩码 -编码器/存储器的输出与解码器/移位输出的输入之间的注意力)。
所以我们知道这是转换器中每个掩码的用途(请注意 pytorch 文档中的符号如下Tx=S is the source sequence length
(例如输入批次的最大值),
Ty=T is the target sequence length
(例如目标长度的最大值)
B=N is the batch size
,,
D=E is the feature number
):
src_mask
[Tx, Tx] = [S, S]
– src 序列的附加掩码(可选)。这是在做的时候应用的atten_src + src_mask
。我不确定示例输入 - 请参阅 tgt_mask 示例,但典型用途是添加-inf
,以便可以根据需要以这种方式屏蔽 src_attention。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。tgt_mask
[Ty, Ty] = [T, T]
– tgt 序列的附加掩码(可选)。这是在做的时候应用的atten_tgt + tgt_mask
。一个示例使用是避免解码器作弊的对角线。所以 tgt 是右移的,第一个令牌是嵌入 SOS/BOS 的序列令牌的开始,因此第一个条目为零而其余条目。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。memory_mask
[Ty, Tx] = [T, S]
– 编码器输出的附加掩码(可选)。这是在做的时候应用的atten_memory + memory_mask
。不确定使用示例,但如前所述,添加-inf
将一些注意力权重设置为零。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。src_key_padding_mask
[B, Tx] = [N, S]
– 每批 src 键的 ByteTensor 掩码(可选)。由于您的 src 通常具有不同长度的序列,因此通常会删除您在末尾附加的填充向量。为此,您可以指定批次中每个示例的每个序列的长度。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。tgt_key_padding_mask
[B, Ty] = [N, t]
– 每批 tgt 键的 ByteTensor 掩码(可选)。和以前一样。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。memory_key_padding_mask
[B, Tx] = [N, S]
– 每批内存键的 ByteTensor 掩码(可选)。和以前一样。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。
附录
pytorch 教程中的示例(https://pytorch.org/tutorials/beginner/translation_transformer.html):
1 src_mask 示例
src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
返回大小为布尔值的张量[Tx, Tx]
:
tensor([[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False]])
2 tgt_mask 示例
mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1)
mask = mask.transpose(0, 1).float()
mask = mask.masked_fill(mask == 0, float('-inf'))
mask = mask.masked_fill(mask == 1, float(0.0))
为解码器的输入生成右移输出的对角线。
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
...,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.]])
通常右移输出的开头有 BOS/SOS,教程只需在前面附加 BOS/SOS,然后用 . 修剪最后一个元素即可获得右移tgt_input = tgt[:-1, :]
。
3 _填充
填充只是为了掩盖最后的填充。src 填充通常与内存填充相同。tgt 有它自己的序列,因此它有自己的填充。例子:
src_padding_mask = (src == PAD_IDX).transpose(0, 1)
tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
memory_padding_mask = src_padding_mask
输出:
tensor([[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., True, True, True]])
请注意,aFalse
表示那里没有填充标记(所以是的,在变压器前向传递中使用该值),并且 aTrue
表示存在填充标记(因此将其屏蔽,因此变压器前向传递不会受到影响)。
答案有点分散,但我发现只有这 3 个参考有用(诚实的单独层文档/东西不是很有用):
推荐阅读
- visual-studio - 禁用属性/参数名称信息 VS
- c - 对于结构无法正常工作,不在非素数后打印短语
- python - 使用 asyncio 在后台运行方法
- reactjs - 如何使用 React 测试库测试响应式 React 组件
- google-chrome - 如何使用 Chrome Incognito 运行 Cypress 为首的测试
- .net - .NET Interactive Notebooks: How do I overwrite a nuget package version?
- dataframe - 将字符串转换为 pyspark.sql.types.StructType pyspark
- swift - 我们可以在 swift 中显示两个固定高度的 CollectionView 部分吗?
- r - 根据R中的另一个变量计算多个日期差异
- sql - SQL - 用计算值替换 NULL 值