tensorflow - 自定义 Keras Metrics Class -> Metric 在某个召回值
问题描述
我正在尝试构建一个与 metrics.PrecisionAtRecall 类相当的指标。因此,我尝试通过扩展 keras.metrics.Metric 类来构建自定义指标。
原始函数是WSS = (TN + FN)/N - 1 + TP/(TP + FN),这应该在某个召回值下计算,例如 95%。
到目前为止,我所拥有的是以下内容:
class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
def __init__(self, recall, name='wss_at_recall', **kwargs):
super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
self.wss = self.add_weight(name='wss', initializer='zeros')
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), tf.float32)
y_pred_neg = 1 - y_pred_pos
y_pos = tf.cast(backend.round(backend.clip(y_true, 0, 1)), tf.float32)
y_neg = 1 - y_pos
fn = backend.sum(y_neg * y_pred_pos)
tn = backend.sum(y_neg * y_pred_neg)
tp = backend.sum(y_pos * y_pred_pos)
n = len(y_true) # number of studies in batch
r = tp/(tp+fn+backend.epsilon()) # recall
self.wss.assign(((tn+fn)/n)-(1+r))
def result(self):
return self.wss
def reset_states(self):
# The state of the metric will be reset at the start of each epoch.
self.wss.assign(0.)
如何计算某个召回时的 WSS?我在 tensorflow 自己的 git 存储库中看到了以下内容:
def __init__(self, recall, num_thresholds=200, name=None, dtype=None):
if recall < 0 or recall > 1:
raise ValueError('`recall` must be in the range [0, 1].')
self.recall = recall
self.num_thresholds = num_thresholds
super(PrecisionAtRecall, self).__init__(
value=recall,
num_thresholds=num_thresholds,
name=name,
dtype=dtype)
但这通过 keras.metrics.Metric 类是不可能的
解决方案
如果我们遵循本文给出的 WSS@95 的定义:Reducing Workload in Systematic Review Preparation Using Automated Citation Classification,那么我们有
对于目前的工作,我们将召回率固定为 0.95,因此在 95% 召回率 (WSS@95%) 的抽样中节省的工作是:
您可以通过以下方式定义更新功能:
class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
def __init__(self, recall, name='wss_at_recall', **kwargs):
if recall < 0 or recall > 1:
raise ValueError('`recall` must be in the range [0, 1].')
self.recall = recall
super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
self.wss = self.add_weight(name='wss', initializer='zeros')
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), tf.float32)
y_pred_neg = 1 - y_pred_pos
y_neg = 1 - y_pos
fn = backend.sum(y_neg * y_pred_pos)
tn = backend.sum(y_neg * y_pred_neg)
n = len(y_true) # number of studies in batch
self.wss.assign(((tn+fn)/n)-(1-self.recall))
另一种解决方案是从 tensorflow 类扩展SensitivitySpecificityBase
并在实现 PresicionAtRecall 类时实现 WSS。
通过使用此类,以下是 WSS 的计算方式:
- 计算所有阈值的召回率(默认为 200 个阈值)。
- 找到召回率最接近请求值的阈值索引。(在这种情况下为 0.95)。
- 计算该索引处的 WSS。
阈值的数量用于匹配给定的召回率。
import tensorflow as tf
from tensorflow.python.keras.metrics import SensitivitySpecificityBase
class WorkSavedOverSamplingAtRecall(SensitivitySpecificityBase):
def __init__(self, recall, num_thresholds=200, name="wss_at_recall", dtype=None):
if recall < 0 or recall > 1:
raise ValueError('`recall` must be in the range [0, 1].')
self.recall = recall
self.num_thresholds = num_thresholds
super(WorkSavedOverSamplingAtRecall, self).__init__(
value=recall, num_thresholds=num_thresholds, name=name, dtype=dtype
)
def result(self):
recalls = tf.math.div_no_nan(
self.true_positives, self.true_positives + self.false_negatives
)
n = self.true_negatives + self.true_positives + self.false_negatives + self.false_positives
wss = tf.math.div_no_nan(
self.true_negatives+self.false_negatives, n
)
return self._find_max_under_constraint(
recalls, wss, tf.math.greater_equal
)
def get_config(self):
"""For serialization purposes"""
config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
base_config = super(WorkSavedOverSamplingAtRecall, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
推荐阅读
- python - 如何找到数组中差值大于“n”的相邻对的数量?
- http - 没有端口的 HTTP X-Forwarded-Host 行为
- spring - 同一端点的 Spring Boot MVC 发送两个不同的响应(略有不同)
- android - 我正在制作一个像井字游戏这样的应用程序,单击按钮后我的应用程序崩溃了,我认为网格布局存在一些问题
- forms - Powershell Windows窗体边框颜色/控件?
- node.js - 节点不会安装在新项目上
- flutter - 如何找到给定2点的垂直线?
- macos - zsh:找不到命令:pub
- java - 无法识别 java-library 插件的 api 配置
- javascript - React Web // iOS 就像移动设备上的选择下拉菜单