python - Keras 自定义损失函数 - 生存分析截尾
问题描述
我正在尝试在 Keras 中构建一个自定义损失函数 - 这用于生存分析中的审查数据。
这个损失函数本质上是二元交叉熵,即多标签分类,但是损失函数中的求和项需要根据 Y_true 中标签的可用性而变化。请参见下面的示例:
示例 1:适用于 Y_True 的所有标签
Y_true = [0, 0, 0, 1, 1]
Y_pred = [0.1, 0.2, 0.2, 0.8, 0.7]
损失 = -1/5(log(0.9) + log(0.8) + log(0.8) + log(0.8) + log(0.7)) = 0.22
示例 2:只有两个标签可用于 Y_True
Y_true = [0, 0, -, -, -]
Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]
损失 = -1/2 (log(0.9) + log(0.8)) = 0.164
示例 3:只有一个标签可用于 Y_True
Y_true = [0, -, -, -, -]
Y_pred = [0.1, 0.2, 0.1, 0.9, 0.9]
损失 = -1 (log(0.9)) = 0.105
在示例一的情况下,我们的损失将通过上面的公式计算,K = 5。在示例二中,我们的损失将通过 K = 2 计算(即仅根据基本事实中可用的前两项进行评估)。损失函数需要根据 Y_true 可用性进行调整。
我尝试过自定义 Keras 损失函数......但是我正在努力研究如何基于张量流中的 nan 索引进行过滤。有人对上述自定义损失函数的编码有什么建议吗?
def nan_binary_cross_entropy(y_actual, y_predicted):
stack = tf.stack((tf.is_nan(y_actual), tf.is_nan(y_predicted)),axis=1)
is_nans = K.any(stack, axis=1)
per_instance = tf.where(is_nans, tf.zeros_like(y_actual),
tf.square(tf.subtract(y_predicted, y_actual)))
FILTER HERE
return K.binary_crossentropy(y_filt, y_filt), axis=-1)
解决方案
您可以使用tf.math.is_nan
和的组合tf.math.multiply_no_nan
来掩盖您的y_true
以获得所需的结果。
import numpy as np
import tensorflow as tf
y_true = tf.constant([
[0.0, 0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, np.nan, np.nan, np.nan],
[0.0, np.nan, np.nan, np.nan, np.nan],
])
y_pred = tf.constant([
[0.1, 0.2, 0.2, 0.8, 0.7],
[0.1, 0.2, 0.1, 0.9, 0.9],
[0.1, 0.2, 0.1, 0.9, 0.9],
])
def survival_loss_fn(y_true, y_pred):
# create a mask for NaN elements
mask = tf.cast(~tf.math.is_nan(y_true), tf.float32)
# sum along the row axis of the mask to find the `N`
# for each training instance
Ns = tf.math.reduce_sum(mask, 1)
# use `multiply_no_nan` to zero out the NaN in `y_pred`
fst = tf.math.multiply_no_nan(y_true, mask) * tf.math.log(y_pred)
snd = tf.math.multiply_no_nan(1.0 - y_true, mask) * tf.math.log(1.0 - y_pred)
return -tf.math.reduce_sum(fst + snd, 1) / Ns
survival_loss_fn(y_true, y_pred)
# <tf.Tensor: shape=(3,), [0.22629324, 0.16425204, 0.10536055], dtype=float32)>
推荐阅读
- python - 如何在没有 tk 窗口的情况下在 tkinter 中设置消息框的坐标
- hyperledger-fabric - 我试图在我的本地网络中部署外部链代码。但我收到此错误链代码注册失败:容器以 0 退出”
- prometheus - 我可以将查询/过滤器应用于仪表板级别的所有面板查询吗?
- node.js - Mongoose 不更新空字段
- swift - SwiftUI `navigationBarItems` 与 `toolbar` 有什么区别?
- excel - VBA Excel 将指定范围的工作表导出为 PDF
- ios - 只接收活动场景阶段,它不会继续 XCode 12.4 IOS 14
- c# - C# Azure 存储队列 - 能写、不能读和奇怪的东西
- python - 使用 Python 在 XPATH / Selenium 中查找最后一个 div
- reactjs - (React with Webpack) react-router-dom 模块在开发中工作,但在部署时不起作用