首页 > 解决方案 > 我应该如何实现计算整个预测的 tf.keras.Metric?

问题描述

tf.keras.Metric接口提供了一个有用的工具来实现附加指标,例如损失/准确性。该接口旨在在update_state(self, y_pred, y_true)调用时更新批次,结果应在result(self). 然而,在实施 FID、Inception Score、Expected Calibration Error 等指标时,我们必须查看整个预测集,而不是单个样本的迭代视图。我应该如何在 Tensorflow 中实现这样的自定义指标?我应该使用其他 API 吗?

标签: pythontensorflowkeras

解决方案


据我了解你的问题;

class BinaryTruePositives(tf.keras.metrics.Metric):

  def __init__(self, name='binary_true_positives', **kwargs):
    super(BinaryTruePositives, self).__init__(name=name, **kwargs)
    self.true_positives = self.add_weight(name='tp', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.bool)

    values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
    values = tf.cast(values, self.dtype)
    if sample_weight is not None:
      sample_weight = tf.cast(sample_weight, self.dtype)
      sample_weight = tf.broadcast_to(sample_weight, values.shape)
      values = tf.multiply(values, sample_weight)
    self.true_positives.assign_add(tf.reduce_sum(values))

  def result(self):
    return self.true_positives

[来自tf2 docs的示例]
update_state()我们实现度量计算功能的方法中。本质上,update_state()为整个小批量返回一个。因此,在 tf2 中可以进行小批量评估。但是,为了支持批量评估指标,您还必须更新整个输入管道。


推荐阅读