首页 > 解决方案 > 计算 rmse 时 Pytorch 掩码缺失值

问题描述

我正在尝试计算两个火炬张量的均方根误差。我想忽略/屏蔽标签为 0 的行(缺失值)。我如何修改这一行以考虑到该限制?

torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()

先感谢您。

标签: pytorch

解决方案


这可以通过定义自定义 MSE 损失函数* 来解决,该函数从输入和目标张量中屏蔽缺失值(在您的情况下为 0):

def mse_loss_with_nans(input, target):

    # Missing data are nan's
    # mask = torch.isnan(target)

    # Missing data are 0's
    mask = target == 0

    out = (input[~mask]-target[~mask])**2
    loss = out.mean()

    return loss

(*)从优化的角度来看,计算 MSE 等同于 RMSE——具有计算速度更快的优势。


推荐阅读