首页 > 解决方案 > 试图用另一个相同维度的张量来掩盖张量,得到“索引 1 超出尺寸 1 的维度 0 的范围”

问题描述

    attn_weights = F.softmax(self.attn(torch.cat((input, hidden_cat), 2)), dim=2)
    attn_weights[mask] = float('-inf')
    attn_applied = torch.bmm(attn_weights.transpose(0,1),encoder_outputs.transpose(0,1)).transpose(0,1)
    attn_output = torch.cat((input, attn_applied), 2)

所以我试图将掩码中所有等于 1 的索引设置为负无穷大,但是那条线

attn_weights[mask] = float('-inf')

不断抛出此异常“索引 1 超出尺寸 1 的维度 0 的范围”不确定发生了什么 attn_weights 和掩码都具有相同的维度,即 1 x 2048 x 40。

标签: pytorchmaskattention-model

解决方案


原来掩码张量的 dtype 必须是 torch.uint8 或 torch.bool 我有它 torch.long


推荐阅读