首页 > 解决方案 > 需要使用 if 语句的自定义损失函数

问题描述

我正在尝试训练输出 3 个值的 DNN,(x,y,z)其中xy是我正在寻找的对象的坐标,并且z是对象存在的概率

我需要自定义损失函数:

如果z_true<0.5我不关心xy值,那么错误应该等于(0, 0, sqr(z_true - z_pred))

否则错误应该像(sqr(x_true - x_pred), sqr(y_true - y_pred), sqr(z_true - z_pred))

我正在努力将张量和 if 语句混合在一起。

标签: tensorflowmachine-learningkerasloss-function

解决方案


也许这个自定义损失函数的例子会让你启动并运行。它展示了如何将张量与 if 语句混合。

 def conditional_loss_function(l):
        def loss(y_true, y_pred):
            if l == 0: 
                return loss_funtion1(y_true, y_pred)
            else: 
                return loss_funtion2(y_true, y_pred)
        return loss

 model.compile(loss=conditional_loss_function(l), optimizer=...)

推荐阅读