首页 > 解决方案 > NLLLoss 只是一个正常的负函数?

问题描述

我很难理解nn.NLLLoss()

由于下面的代码总是打印,那么使用负号 (-) 和使用负号 (-)有什么True区别?nn.NLLLoss()

import torch
while 1:
   b = torch.randn(1)
   print(torch.nn.NLLLoss()(b, torch.tensor([0])) == -b[0])

标签: pythonmachine-learningdeep-learningpytorchloss-function

解决方案


在您的情况下,每个批次元素只有一个输出值,目标是0. 损失将nn.NLLLoss选择与目标张量中包含的索引相对应的预测张量的值。这是一个更一般的示例,您总共有五个批处理元素,每个批处理元素具有三个 logit 值:

>>> logits = torch.randn(5, 3, requires_grad=True)
>>> y = torch.tensor([1, 0, 2, 0, 1])
>>> y_hat = torch.softmax(b, -1)

张量yy_hat分别对应于目标张量和估计分布。您可以nn.NLLLoss使用以下方法实现:

>>> -y_hat[torch.arange(len(y_hat)), y]
tensor([-0.2195, -0.1015, -0.3699, -0.5203, -0.1171], grad_fn=<NegBackward>)

与内置函数相比:

>>> F.nll_loss(y_hat, y, reduction='none')
tensor([-0.2195, -0.1015, -0.3699, -0.5203, -0.1171], grad_fn=<NllLossBackward>)

-y_hat这与独自一人完全不同。


推荐阅读