Custom Keras Metrics Class -> Metric at a Certain Recall Value

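Problem description

I am trying to build a metric comparable to the metrics.PrecisionAtRecall class. To do so, I am building a custom metric by extending the keras.metrics.Metric class.

The original function is WSS = (TN + FN)/N - 1 + TP/(TP + FN), which should be computed at a certain recall value, e.g. 95%.

This is what I have so far: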

class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
    def __init__(self, recall, name='wss_at_recall', **kwargs):
        super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
        self.wss = self.add_weight(name='wss', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), tf.float32)
        y_pred_neg = 1 - y_pred_pos
        y_pos = tf.cast(backend.round(backend.clip(y_true, 0, 1)), tf.float32)
        y_neg = 1 - y_pos

        fn = backend.sum(y_neg * y_pred_pos)
        tn = backend.sum(y_neg * y_pred_neg)
        tp = backend.sum(y_pos * y_pred_pos)
        n = len(y_true) # number of studies in batch
        r = tp/(tp+fn+backend.epsilon()) # recall
        self.wss.assign(((tn+fn)/n)-(1+r))

    def result(self):
        return self.wss

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.wss.assign(0.)

How can I compute WSS at a certain recall? I saw the following in TensorFlow's own git repository:

def __init__(self, recall, num_thresholds=200, name=None, dtype=None):
    if recall < 0 or recall > 1:
        raise ValueError('`recall` must be in the range [0, 1].')
    self.recall = recall
    self.num_thresholds = num_thresholds
    super(PrecisionAtRecall, self).__init__(
        value=recall,
        num_thresholds=num_thresholds,
        name=name,
        dtype=dtype)

But this is not possible through the keras.metrics.Metric class.

Tags: tensorflow, keras, metrics

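Solution

If we follow the definition of WSS@95 given in the paper Reducing Workload in Systematic Review Preparation Using Automated Citation Classification, then we have:

For the present work, we have fixed recall at 0.95, and therefore work saved over sampling at 95% recall (WSS@95%) is: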

WSS@95 = (TN+FN)/N - 0.05
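
As a quick sanity check of the formula, here is a minimal plain-Python sketch that evaluates WSS from confusion-matrix counts. The function names and the counts are hypothetical and only serve as illustration; they do not come from the paper or from TensorFlow.

```python
def wss(tn, fn, tp, fp):
    """Work saved over sampling: (TN + FN) / N - (1 - recall)."""
    n = tn + fn + tp + fp
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return (tn + fn) / n - (1.0 - recall)


def wss_at_fixed_recall(tn, fn, n, recall=0.95):
    """Same quantity with the recall pinned to a target value."""
    return (tn + fn) / n - (1.0 - recall)


# Hypothetical counts: 1000 citations, 100 relevant, 95 of them retrieved.
tn, fn, tp, fp = 600, 5, 95, 300
print(wss(tn, fn, tp, fp))                       # general formula
print(wss_at_fixed_recall(tn, fn, n=1000))       # recall fixed at 0.95
```

With these counts the measured recall is exactly 0.95, so both functions agree up to floating-point rounding.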

You can define the update function as follows:

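import tensorflow as tf
from tensorflow.keras import backend


class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
    def __init__(self, recall, name='wss_at_recall', **kwargs):
        if recall < 0 or recall > 1:
            raise ValueError('`recall` must be in the range [0, 1].')
        super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
        self.recall = recall
        self.wss = self.add_weight(name='wss', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Binarize predictions and labels by rounding at 0.5.
        y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), tf.float32)
        y_pred_neg = 1 - y_pred_pos
        y_pos = backend.round(backend.clip(tf.cast(y_true, tf.float32), 0, 1))
        y_neg = 1 - y_pos

        fn = backend.sum(y_pos * y_pred_neg)  # positives predicted as negative
        tn = backend.sum(y_neg * y_pred_neg)  # negatives predicted as negative
        n = tf.cast(tf.size(y_true), tf.float32)  # number of studies in batch
        # Plugs the target recall into the formula; no threshold search is done here.
        # assign() overwrites the state, so the value reflects the latest batch only.
        self.wss.assign(((tn + fn) / n) - (1 - self.recall))

    def result(self):
        return self.wss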

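Another solution is to extend TensorFlow's SensitivitySpecificityBase class and implement WSS the same way the PrecisionAtRecall class is implemented.

With this class, WSS is computed as follows:

  • Compute the recall at every threshold (200 thresholds by default).
  • Keep the thresholds whose recall is at least the requested value (0.95 in this case).
  • Return the highest WSS among those thresholds.

The number of thresholds determines how closely the requested recall can be matched.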
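
To make the threshold sweep concrete, here is a minimal pure-Python sketch of the same idea. It is a simplification under stated assumptions, not TensorFlow's implementation: the function name `wss_at_recall` is hypothetical, thresholds are evenly spaced on [0, 1], a score counts as positive when it is strictly greater than the threshold, and the selection rule is "recall at least the target", mirroring the `tf.math.greater_equal` constraint used in the class that follows.

```python
def wss_at_recall(y_true, y_score, recall=0.95, num_thresholds=200):
    """Best WSS over evenly spaced thresholds whose recall >= `recall`."""
    n = len(y_true)
    best = None
    for i in range(num_thresholds):
        threshold = i / (num_thresholds - 1)
        tp = fn = tn = 0
        for label, score in zip(y_true, y_score):
            predicted_pos = score > threshold
            if label == 1 and predicted_pos:
                tp += 1
            elif label == 1:
                fn += 1
            elif not predicted_pos:
                tn += 1
        achieved = tp / (tp + fn) if (tp + fn) else 0.0
        if achieved >= recall:
            candidate = (tn + fn) / n - (1.0 - recall)
            best = candidate if best is None else max(best, candidate)
    # Fall back to 0.0 when no threshold reaches the target recall.
    return 0.0 if best is None else best


# Hypothetical scores: 2 relevant studies out of 10.
y_true = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
y_score = [0.9, 0.8, 0.7, 0.3, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1]
print(wss_at_recall(y_true, y_score, recall=0.95))
```

A larger `num_thresholds` gives a finer grid, so the chosen threshold can sit closer to the point where recall drops below the target, at the cost of more computation.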

import tensorflow as tf
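# Note: tensorflow.python is a private namespace; this import path may differ between TensorFlow versions.
from tensorflow.python.keras.metrics import SensitivitySpecificityBase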


class WorkSavedOverSamplingAtRecall(SensitivitySpecificityBase):
    def __init__(self, recall, num_thresholds=200, name="wss_at_recall", dtype=None):
        if recall < 0 or recall > 1:
            raise ValueError('`recall` must be in the range [0, 1].')
        self.recall = recall
        self.num_thresholds = num_thresholds
        super(WorkSavedOverSamplingAtRecall, self).__init__(
            value=recall, num_thresholds=num_thresholds, name=name, dtype=dtype
        )

    def result(self):
        recalls = tf.math.div_no_nan(
            self.true_positives, self.true_positives + self.false_negatives
        )
        n = self.true_negatives + self.true_positives + self.false_negatives + self.false_positives
        # WSS = (TN + FN) / N - (1 - recall), evaluated at every threshold.
        wss = tf.math.div_no_nan(
            self.true_negatives+self.false_negatives, n
        ) - (1 - self.recall)
        return self._find_max_under_constraint(
            recalls, wss, tf.math.greater_equal
        )

    def get_config(self):
        """For serialization purposes"""
        config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
        base_config = super(WorkSavedOverSamplingAtRecall, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
