deep-learning - pytorch 中的交叉熵损失如何工作?
问题描述
我正在尝试一些 pytorch 代码。通过交叉熵损失,我发现了一些有趣的结果,并且我使用了 pytorch 的二元交叉熵损失和交叉熵损失。
import torch
import torch.nn as nn
X = torch.tensor([[1,0],[1,0],[0,1],[0,1]],dtype=torch.float)
softmax = nn.Softmax(dim=1)
bce_loss = nn.BCELoss()
ce_loss= nn.CrossEntropyLoss()
pred = softmax(X)
bce_loss(X,X) # tensor(0.)
bce_loss(pred,X) # tensor(0.3133)
bce_loss(pred,pred) # tensor(0.5822)
ce_loss(X,torch.argmax(X,dim=1)) # tensor(0.3133)
我预计相同输入和输出的交叉熵损失为零。这里 X, pred 和 torch.argmax(X,dim=1) 是相同/相似的一些变换。这种推理仅适用于bce_loss(X,X) # tensor(0.)
其他所有导致损失大于零的情况。我推测 , 的输出bce_loss(pred,X)
应该bce_loss(pred,pred)
为零ce_loss(X,torch.argmax(X,dim=1))
。
这里有什么错误?
解决方案
您看到这个的原因是因为nn.CrossEntropyLoss
接受 logits 和目标,也就是 X 应该是 logits,但已经在 0 和 1 之间。X
应该更大,因为在 softmax 之后它将在 0 和 1 之间。
ce_loss(X * 1000, torch.argmax(X,dim=1)) # tensor(0.)
nn.CrossEntropyLoss
与 logits 一起使用,以利用 log sum 技巧。
激活后您当前尝试的方式,您的预测将变为大约[0.73, 0.26]
.
二进制交叉熵示例有效,因为它接受已激活的 logits。顺便说一句,您可能想使用nn.Sigmoid
激活二元交叉熵 logits。对于 2-class 示例,softmax 也可以。
推荐阅读
- python - 使用 Joblib 并行化将 python 脚本提交到 Sun Grid Engine
- reactjs - 在 React 中,setState 回调和在状态变量上使用 useEffect 有什么区别?
- datatable - 如何修复闪亮应用程序中 data.table 标题的位置
- javascript - 我正在寻找替代方法来比较 2 个数组,找到相等和重新分配
- python - 如何跨模块保留全局变量?
- javascript - 文档获取元素未在我的开始游戏功能中执行
- python - 转换为 df 列到日期时间 - 提高 SettingWithCopyWarning
- html - 为什么将鼠标悬停在主 div 上会显示下拉菜单但按钮不显示?
- cassandra - Cassandra (Datastax) CQL 忽略 TEXT 列的大小写
- ios - 从平台方法调用返回布尔值