python - 试图理解 PyTorch SmoothL1Loss 的实现
问题描述
我一直在尝试通过 PyTorch 中的所有损失函数并从头开始构建它们以更好地理解它们,并且我遇到了我的娱乐问题或 PyTorch 实现的问题。
根据 Pytorch 的 SmoothL1Loss 文档,它简单地说,如果预测的绝对值减去基本事实小于 beta,我们使用顶部方程。否则,我们使用底部的。请参阅方程式的文档。
以下是我以最小测试的形式对此的实现:
import torch
import torch.nn as nn
import numpy as np
predictions = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
def l1_loss_smooth(predictions, targets, beta = 1.0):
loss = 0
for x, y in zip(predictions, targets):
if abs(x-y).mean() < beta:
loss += (0.5*(x-y)**2 / beta).mean()
else:
loss += (abs(x-y) - 0.5 * beta).mean()
loss = loss/predictions.shape[0]
output = l1_loss_smooth(predictions, target)
print(output)
Gives an output of:
tensor(0.7475, grad_fn=<DivBackward0>)
现在 Pytorch 实现:
loss = nn.SmoothL1Loss(beta=1.0)
output = loss(predictions, target)
Gives an output of:
tensor(0.7603, grad_fn=<SmoothL1LossBackward>)
我无法弄清楚实施中的错误在哪里。
在深入了解smooth_l1_loss function
模块_C
(文件:)后smooth_c_loss_op.cc
,我注意到文档字符串提到它是 Huber Loss 的变体,但文档SmoothL1Loss
说它是 Huber Loss。
所以总的来说,只是对它的实现方式以及它是否是 SmoothL1Loss 和 Huber Loss、Just Huber Loss 或其他东西的组合感到困惑。
解决方案
文档中的描述是正确的。您的实现错误地将案例选择应用于数据的平均值。相反,它应该是一个元素选择(如果你考虑普通 L1 损失的实现,以及平滑 L1 损失的动机)。
以下代码给出了一致的结果:
import torch
import torch.nn as nn
import numpy as np
predictions = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
def l1_loss_smooth(predictions, targets, beta = 1.0):
loss = 0
diff = predictions-targets
mask = (diff.abs() < beta)
loss += mask * (0.5*diff**2 / beta)
loss += (~mask) * (diff.abs() - 0.5*beta)
return loss.mean()
output = l1_loss_smooth(predictions, target)
print(output)
loss = nn.SmoothL1Loss(beta=1.0)
output = loss(predictions, target)
print(output)
推荐阅读
- sql-server - 选择列表中只能指定单个表达式,MSSQL查询
- heroku - React-BoilerPlate 部署到 heroku 成功但应用程序错误
- visual-studio - Service Fabric 应用程序模板在 Visual Studio 中不可用
- javascript - php和javascript替换正则表达式
- json - 删除/替换 Azure 逻辑应用中解析的 JSON 输出中的字符
- typo3 - 部署脚本无法生成 PackageStates.php,而相同的 Shell-Command 可以工作。(TYPO3)
- python-3.x - 通过 Python 连接到 Sharepoint - UrlLib 错误
- r - 按序列将向量拆分为组
- css - 旋转对象并从顶部对其进行缩放
- microsoft-graph-api - Microsoft 图形中的法语口音