首页 > 解决方案 > 使用 metrics=['accuracy'] 时,Keras 中使用什么精度函数?

问题描述

我想知道使用时使用了什么精度函数metrics=['accuracy']

model.compile(loss='binary_crossentropy', metrics=['accuracy'])

我看到了keras/metrics.py文件。该文件中有some_accuracy函数。但我找不到accuracy。我在哪里可以看到源代码以及它是如何工作的?

标签: pythonmachine-learningkeras

解决方案


从您使用的损失函数中自动推断出适当的准确度函数。正如您所提到的,精度函数已在metric.py文件中定义:

def binary_accuracy(y_true, y_pred):
    return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)


def categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)),
                  K.floatx())


def sparse_categorical_accuracy(y_true, y_pred):
    # flatten y_true in case it's in shape (num_samples, 1) instead of (num_samples,)
    return K.cast(K.equal(K.flatten(y_true),
                          K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
                          K.floatx())
  • binary_accuracy当模型的最后一个输出轴的维度等于一(即(..., 1))或binary_crossentropy用作损失时使用。

  • sparse_categorical_accuracysparse_categorical_crossentropy用作损失函数时使用。

  • 最后,categorical_accuracycategorical_crossentropy已设置为损失函数时使用。

另外,请注意,该accuracy指标仅对分类任务有效。因此,如果您accuarcy在回归任务中用作指标,则报告的指标值可能根本无效。

此外,还有另外两个内置的精度函数:

def top_k_categorical_accuracy(y_true, y_pred, k=5):
    return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)


def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
    return K.mean(K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)

要使用它们,您需要将它们的名称显式传递给metrics参数,即metrics=['top_k_categorical_accuracy']metrics=['sparse_top_k_categorical_accuracy']


推荐阅读