首页 > 解决方案 > src_mask 和 src_key_padding_mask 的区别

问题描述

我很难理解变形金刚。一切都变得越来越清晰,但让我头疼的一件事是 src_mask 和 src_key_padding_mask 之间的区别,它在编码器层和解码器层的前向函数中作为参数传递。

https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#Transformer

标签: pytorchtransformer

解决方案


src_mask 和 src_key_padding_mask 的区别

一般的事情是注意使用张量_mask_key_padding_mask. 在注意完成后,在转换器内部,我们通常会得到一个平方中间张量,其中包含所有大小比较[Tx, Tx](对于编码器的输入)、[Ty, Ty](对于移位的输出 - 解码器的输入之一)和[Ty, Tx](对于内存掩码 -编码器/存储器的输出与解码器/移位输出的输入之间的注意力)。

所以我们知道这是转换器中每个掩码的用途(请注意 pytorch 文档中的符号如下Tx=S is the source sequence length (例如输入批次的最大值), Ty=T is the target sequence length(例如目标长度的最大值) B=N is the batch size,, D=E is the feature number):

  1. src_mask [Tx, Tx] = [S, S]– src 序列的附加掩码(可选)。这是在做的时候应用的atten_src + src_mask。我不确定示例输入 - 请参阅 tgt_mask 示例,但典型用途是添加-inf,以便可以根据需要以这种方式屏蔽 src_attention。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。

  2. tgt_mask [Ty, Ty] = [T, T]– tgt 序列的附加掩码(可选)。这是在做的时候应用的atten_tgt + tgt_mask。一个示例使用是避免解码器作弊的对角线。所以 tgt 是右移的,第一个令牌是嵌入 SOS/BOS 的序列令牌的开始,因此第一个条目为零而其余条目。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。

  3. memory_mask [Ty, Tx] = [T, S]– 编码器输出的附加掩码(可选)。这是在做的时候应用的atten_memory + memory_mask。不确定使用示例,但如前所述,添加-inf将一些注意力权重设置为零。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。

  4. src_key_padding_mask [B, Tx] = [N, S]– 每批 src 键的 ByteTensor 掩码(可选)。由于您的 src 通常具有不同长度的序列,因此通常会删除您在末尾附加的填充向量。为此,您可以指定批次中每个示例的每个序列的长度。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。

  5. tgt_key_padding_mask [B, Ty] = [N, t]– 每批 tgt 键的 ByteTensor 掩码(可选)。和以前一样。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。

  6. memory_key_padding_mask [B, Tx] = [N, S]– 每批内存键的 ByteTensor 掩码(可选)。和以前一样。具体例子见附录。如果提供了 ByteTensor,则不允许非零位参加,而零位将保持不变。如果提供了 BoolTensor,则不允许出现 True 的位置,而 False 值将保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。

附录

pytorch 教程中的示例(https://pytorch.org/tutorials/beginner/translation_transformer.html):

1 src_mask 示例

    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

返回大小为布尔值的张量[Tx, Tx]

tensor([[False, False, False,  ..., False, False, False],
         ...,
        [False, False, False,  ..., False, False, False]])

2 tgt_mask 示例

    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1)
    mask = mask.transpose(0, 1).float()
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, float(0.0))

为解码器的输入生成右移输出的对角线。

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf],
         ...,
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.]])

通常右移输出的开头有 BOS/SOS,教程只需在前面附加 BOS/SOS,然后用 . 修剪最后一个元素即可获得右移tgt_input = tgt[:-1, :]

3 _填充

填充只是为了掩盖最后的填充。src 填充通常与内存填充相同。tgt 有它自己的序列,因此它有自己的填充。例子:

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    memory_padding_mask = src_padding_mask

输出:

tensor([[False, False, False,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ...,  True,  True,  True]])

请注意,aFalse表示那里没有填充标记(所以是的,在变压器前向传递中使用该值),并且 aTrue表示存在填充标记(因此将其屏蔽,因此变压器前向传递不会受到影响)。


答案有点分散,但我发现只有这 3 个参考有用(诚实的单独层文档/东西不是很有用):


推荐阅读