首页 > 解决方案 > Keras 自定义损失函数 - 生存分析截尾

问题描述

我正在尝试在 Keras 中构建一个自定义损失函数 - 这用于生存分析中的审查数据。

这个损失函数本质上是二元交叉熵,即多标签分类,但是损失函数中的求和项需要根据 Y_true 中标签的可用性而变化。请参见下面的示例:

示例 1:适用于 Y_True 的所有标签

Y_true = [0, 0, 0, 1, 1]

Y_pred = [0.1, 0.2, 0.2, 0.8, 0.7]

损失 = -1/5(log(0.9) + log(0.8) + log(0.8) + log(0.8) + log(0.7)) = 0.22

示例 2:只有两个标签可用于 Y_True

Y_true = [0, 0, -, -, -]

Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]

损失 = -1/2 (log(0.9) + log(0.8)) = 0.164

示例 3:只有一个标签可用于 Y_True

Y_true = [0, -, -, -, -]

Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]

损失 = -1 (log(0.9)) = 0.105

在此处输入图像描述

在示例一的情况下,我们的损失将通过上面的公式计算,K = 5。在示例二中,我们的损失将通过 K = 2 计算(即仅根据基本事实中可用的前两项进行评估)。损失函数需要根据 Y_true 可用性进行调整。

我尝试过自定义 Keras 损失函数......但是我正在努力研究如何基于张量流中的 nan 索引进行过滤。有人对上述自定义损失函数的编码有什么建议吗?

def nan_binary_cross_entropy(y_actual, y_predicted):
    stack = tf.stack((tf.is_nan(y_actual), tf.is_nan(y_predicted)),axis=1)
    is_nans = K.any(stack, axis=1)
    per_instance = tf.where(is_nans, tf.zeros_like(y_actual), 
                            tf.square(tf.subtract(y_predicted, y_actual)))

    FILTER HERE
    return K.binary_crossentropy(y_filt, y_filt), axis=-1)

标签: pythontensorflowkeras

解决方案


您可以使用tf.math.is_nan和的组合tf.math.multiply_no_nan来掩盖您的y_true以获得所需的结果。

import numpy as np
import tensorflow as tf


y_true = tf.constant([
    [0.0, 0.0, 0.0, 1.0, 1.0],
    [0.0, 0.0, np.nan, np.nan, np.nan],
    [0.0, np.nan, np.nan, np.nan, np.nan],
])


y_pred = tf.constant([
    [0.1, 0.2, 0.2, 0.8, 0.7],
    [0.1, 0.2, 0.1, 0.9, 0.9],
    [0.1, 0.2, 0.1, 0.9, 0.9],
])


def survival_loss_fn(y_true, y_pred):
    # create a mask for NaN elements
    mask = tf.cast(~tf.math.is_nan(y_true), tf.float32)
    # sum along the row axis of the mask to find the `N`
    # for each training instance
    Ns = tf.math.reduce_sum(mask, 1)
    # use `multiply_no_nan` to zero out the NaN in `y_pred`
    fst = tf.math.multiply_no_nan(y_true, mask) * tf.math.log(y_pred)
    snd = tf.math.multiply_no_nan(1.0 - y_true, mask) * tf.math.log(1.0 - y_pred)
    return -tf.math.reduce_sum(fst + snd, 1) / Ns


survival_loss_fn(y_true, y_pred)
# <tf.Tensor: shape=(3,), [0.22629324, 0.16425204, 0.10536055], dtype=float32)>

推荐阅读