首页 > 解决方案 > pytorch 中的交叉熵损失如何工作?

问题描述

我正在尝试一些 pytorch 代码。通过交叉熵损失,我发现了一些有趣的结果,并且我使用了 pytorch 的二元交叉熵损失和交叉熵损失。

import torch
import torch.nn as nn

X = torch.tensor([[1,0],[1,0],[0,1],[0,1]],dtype=torch.float)
softmax = nn.Softmax(dim=1)


bce_loss = nn.BCELoss()
ce_loss= nn.CrossEntropyLoss()

pred = softmax(X)

bce_loss(X,X) # tensor(0.)
bce_loss(pred,X) # tensor(0.3133)
bce_loss(pred,pred) # tensor(0.5822)

ce_loss(X,torch.argmax(X,dim=1)) # tensor(0.3133)

我预计相同输入和输出的交叉熵损失为零。这里 X, pred 和 torch.argmax(X,dim=1) 是相同/相似的一些变换。这种推理仅适用于bce_loss(X,X) # tensor(0.)其他所有导致损失大于零的情况。我推测 , 的输出bce_loss(pred,X)应该bce_loss(pred,pred)为零ce_loss(X,torch.argmax(X,dim=1))

这里有什么错误?

标签: deep-learningpytorchloss-functioncross-entropy

解决方案


您看到这个的原因是因为nn.CrossEntropyLoss接受 logits 和目标,也就是 X 应该是 logits,但已经在 0 和 1 之间。X应该更大,因为在 softmax 之后它将在 0 和 1 之间。

ce_loss(X * 1000, torch.argmax(X,dim=1)) # tensor(0.)

nn.CrossEntropyLoss与 logits 一起使用,以利用 log sum 技巧。

激活后您当前尝试的方式,您的预测将变为大约[0.73, 0.26].

二进制交叉熵示例有效,因为它接受已激活的 logits。顺便说一句,您可能想使用nn.Sigmoid激活二元交叉熵 logits。对于 2-class 示例,softmax 也可以。


推荐阅读