deep-learning - 如何在 PyTorch 中计算自举交叉熵损失?
问题描述
我读过一些论文,它们使用一种叫做“自举交叉熵损失”的东西来训练他们的分割网络。这个想法是只关注最难的 k%(比如 15%)像素,以提高学习性能,尤其是当简单像素占主导地位时。
目前,我正在使用标准交叉熵:
loss = F.binary_cross_entropy(mask, gt)
如何在 PyTorch 中有效地将其转换为引导版本?
解决方案
通常我们还会在损失中添加一个“热身”期,以便网络可以学习首先适应容易的区域并过渡到较难的区域。
这个实现从k=100
20000 次迭代开始并持续,然后线性衰减到k=15
另外 50000 次迭代。
class BootstrappedCE(nn.Module):
def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
super().__init__()
self.start_warm = start_warm
self.end_warm = end_warm
self.top_p = top_p
def forward(self, input, target, it):
if it < self.start_warm:
return F.cross_entropy(input, target), 1.0
raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
num_pixels = raw_loss.numel()
if it > self.end_warm:
this_p = self.top_p
else:
this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
return loss.mean(), this_p
推荐阅读
- tensorflow - 有没有办法构建一个以指定角度随机旋转的 keras 预处理层?
- r - 使用 Rstudio 抓取亚马逊评论时出错:参数暗示不同的行数:3、10
- java - Android:NullPointerException:尝试在空对象引用上调用接口方法
- javascript - 如何在 React-Big-Calender 中创建自定义视图?
- django - Django Authenticate 方法不适用于扩展“AbstractUser”的用户类
- postgresql - 将旧数据恢复到新版本的 Postgres
- c++ - 缺少 ws2def.h,我可以从哪个包下载?
- javascript - 在反应中动画文本状态变化
- java - Android 与 int 类型的双向绑定
- python-3.x - 如何使用 Cell 方法通过 python 从 excel 单元格中检索值?,使用 python3.9 和 openpyxl