首页 > 解决方案 > 如何将掩码纳入负似然损失 (torch.nn.functional.nll_loss)

问题描述

您好,我正在为家庭作业实施语言建模的 lstm,我正处于亏损实施阶段。我们的教练告诉我们使用 F.nll_loss 但序列被填充,我们必须考虑给定的掩码,它告诉我们序列何时停止。

输入:

在不考虑掩码的情况下工作的幼稚实现:

import torch.nn.functional as F
loss = F.nll_loss(log_probas.transpose(1, 2), targets)

我一直在上网并敲打我的头,但似乎无法找到如何将掩码纳入损失的平均方案的答案。

标签: pythondeep-learningpytorchtorchlanguage-model

解决方案


您可以重塑张量并使用掩码来选择未填充的标记,并计算损失

vocab_size = log_probas.size(-1)
log_probas = log_probas.view(-1, vocab_size)
target = target.view(-1)
mask = mask.view(-1).bool()
loss = F.nll_loss(log_probas[mask], targets[mask])

推荐阅读