pytorch - 计算 rmse 时 Pytorch 掩码缺失值
问题描述
我正在尝试计算两个火炬张量的均方根误差。我想忽略/屏蔽标签为 0 的行(缺失值)。我如何修改这一行以考虑到该限制?
torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()
先感谢您。
解决方案
这可以通过定义自定义 MSE 损失函数* 来解决,该函数从输入和目标张量中屏蔽缺失值(在您的情况下为 0):
def mse_loss_with_nans(input, target):
# Missing data are nan's
# mask = torch.isnan(target)
# Missing data are 0's
mask = target == 0
out = (input[~mask]-target[~mask])**2
loss = out.mean()
return loss
(*)从优化的角度来看,计算 MSE 等同于 RMSE——具有计算速度更快的优势。
推荐阅读
- rtf - RTF 表格标题行
- javascript - 如何修复 TypeError:application_module__WEBPACK_IMPORTED_MODULE_1___default.a 不是构造函数
- python - 从打印的变量中删除字符
- javascript - 如何在页面中显示我从用户那里获得的信息?
- javascript - 需要在 JS 文件中包含 HTML 代码
- aspose - 在 HTML 页面中重复一个部分
- excel - 获取从VBA中的段落获得的行的第一个单词的索引号
- assembly - 我如何为 Risc-V(Assembly Language) 编写旋转操作,我们是否有任何命令,就像我们在 8086 中一样?
- scala - Scala 用户定义的注解类的构造函数什么时候执行?
- angularjs - 如何修复 AngularJS 中未定义的承诺错误