python - 二元交叉熵计算中的 pos_weight
问题描述
当我们处理不平衡的训练数据(负样本多,正样本少)时,通常pos_weight
会使用参数。的期望pos_weight
是,当positive sample
得到错误的标签时,模型将比negative sample
. 当我使用该binary_cross_entropy_with_logits
功能时,我发现:
bce = torch.nn.functional.binary_cross_entropy_with_logits
pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)
preds_neg_wrong = torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)
然而:
>>> loss_pos_wrong
tensor(2.0359)
>>> loss_neg_wrong
tensor(2.0359)
pos_weight
错误的正样本和负样本产生的损失是一样的,那么在不平衡数据损失计算中是如何工作的呢?
解决方案
TLDR;两个损失是相同的,因为您正在计算相同的数量:两个输入是相同的,两个批次元素和标签只是交换了。
为什么你会得到同样的损失?
我认为您对F.binary_cross_entropy_with_logits
(您可以找到更详细的文档页面nn.BCEWithLogitsLoss
)的使用感到困惑。在您的情况下,您的输入形状(也就是模型的输出)是一维的,这意味着您只有一个 logit x
,而不是两个)。
在您的示例中,您有
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
这意味着您的批量大小是2
,并且由于默认情况下该函数是对批量元素的损失进行平均,因此您最终会得到与 和 相同的BCE(preds_pos_wrong, label_pos)
结果BCE(preds_neg_wrong, label_neg)
。您的批次的两个元素刚刚切换。
您可以通过以下选项不平均批次元素的损失来非常轻松地验证这一点reduction='none'
:
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])
调查F.binary_cross_entropy_with_logits
:
话虽如此,二元交叉熵的公式是:
bce = -[y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
其中y
( 分别sigmoid(x)
是与该 logit 相关的正类,1 - y
(resp. 1 - sigmoid(x)
) 是负类。
文档可能更精确地说明权重方案pos_weight
(不要与 混淆weight
,后者是不同 logits 输出的权重)。正如你所说,这个想法pos_weight
是衡量积极的术语,而不是整个术语。
bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
正项的权重在哪里w_p
,以补偿正负样本的不平衡。在实践中,这应该是w_p = #positive/#negative
。
所以:
>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])
使用内置的损失函数,
>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])
与人工计算相比:
>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])
推荐阅读
- java - 找不到匹配的 bean - `dispatcher-servlet.xml` 文件中的 Spring 错误
- svelte - 嵌套布局上的 Svelte Routing 404
- python - 如何根据条件加入数组元素?
- laravel - Laravel foreach 没有关闭
- scala - 未设置线性回归特征
- sql - PostgreSQL SQL 的触发器和重复问题
- security - ArcGIS JavaScript API 中的参数化查询
- vb.net - 画布内文本中的 iText 垂直对齐不起作用
- reactjs - Typescript React/Nextjs styled-components - 将函数作为道具传递给组件会在缺少类型时引发错误
- python - 展平层输出与输入形状不匹配