首页 > 解决方案 > 精度类型

问题描述

精度是通过使用 keras 库获得的:

model.compile(optimizer='sgd',
          loss='mse',
          metrics=[tf.keras.metrics.Precision()])

sklearn计算的精度和keras计算的精度是什么类型的?

precision_score(y_true, y_pred, average=???)
  1. 加权
  2. 没有任何

如下所示将 zero_division 设置为 1 时会发生什么?:

precision_score(y_true, y_pred, average=None, zero_division=1)

标签: pythontensorflowscikit-learnprecision

解决方案


TLDR;默认binary为二分类和micro多分类。其他平均类型,例如Nonemacro也可以通过如下所述的微小修改来实现。


tf.keras.Precision()这应该让您清楚地了解和之间的区别sklearn.metrics.precision_score()。让我们比较不同的场景。

场景一:二分类

对于二进制分类,您的 y_true 和 y_pred 分别为 0,1 和 0-1。两者的实现都非常简单。

Sklearn 文档:仅报告 pos_label 指定的类的结果。这仅适用于目标 (y_{true,pred}) 是二进制的。

#Binary classification

from sklearn.metrics import precision_score
import tensorflow as tf

y_true = [0,1,1,1]
y_pred = [1,0,1,1]

print('sklearn precision: ',precision_score(y_true, y_pred, average='binary'))
#Only report results for the class specified by pos_label. 
#This is applicable only if targets (y_{true,pred}) are binary.

m = tf.keras.metrics.Precision()
m.update_state(y_true, y_pred)
print('tf.keras precision:',m.result().numpy())
sklearn precision:  0.6666666666666666
tf.keras precision: 0.6666667

场景二:多类分类(全局精度)

在这里,您正在使用多类标签,但您不必担心每个单独类的精度如何。您只需要一组全局 TP 和 FP 来计算总精度分数。in sklearnthis 由参数设置micro,而 in tf.kerasthis 是默认设置Precision()

Sklearn 文档:通过计算总的真阳性、假阴性和假阳性来全局计算指标。

#Multi-class classification (global precision)

#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]

print('sklearn precision: ',precision_score(y_true, y_pred, average='micro'))
#Calculate metrics globally by counting the total true positives, false negatives and false positives.

m.reset_states()
m = tf.keras.metrics.Precision()
m.update_state(y_true, y_pred)
print('tf.keras precision:',m.result().numpy())
sklearn precision:  0.3333333333333333
tf.keras precision: 0.33333334

场景 3:多类分类(每个标签的二进制精度)

如果您想知道每个单独类的精度,您会对这种情况感兴趣。sklearn这是通过将average参数设置为 来完成的None,而在tf.keras您必须分别使用 为每个单独的类实例化对象class_id

Sklearn 文档:如果没有,则返回每个班级的分数。

#Multi-class classification (binary precision for each label)

#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]

print('sklearn precision: ',precision_score(y_true, y_pred, average=None))
#If None, the scores for each class are returned.

#For class 0
m0 = tf.keras.metrics.Precision(class_id=0)
m0.update_state(y_true, y_pred)

#For class 1
m1 = tf.keras.metrics.Precision(class_id=1)
m1.update_state(y_true, y_pred)

#For class 2
m2 = tf.keras.metrics.Precision(class_id=2)
m2.update_state(y_true, y_pred)

mm = [m0.result().numpy(), m1.result().numpy(), m2.result().numpy()]

print('tf.keras precision:',mm)
sklearn precision:  [0.66666667 0.         0.        ]
tf.keras precision: [0.6666667, 0.0, 0.0]

场景 4:多类分类(单个二进制分数的平均值)

一旦你计算了每个类的单独精度,你可能想要取平均分(或加权平均值)。在中,通过将参数设置为来sklearn获取单个分数的简单平均值。您可以通过取上述场景中计算的各个精度的平均值来获得相同的结果。averagemacrotf.keras

Sklearn 文档:计算每个标签的指标,并找到它们的未加权平均值。

#Multi-class classification (Average of individual binary scores)

#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]

print('sklearn precision (Macro): ',precision_score(y_true, y_pred, average='macro'))
print('sklearn precision (Avg of None):' ,np.average(precision_score(y_true, y_pred, average=None)))

print(' ')

print('tf.keras precision:',np.average(mm)) #mm is list of individual precision scores
sklearn precision (Macro):  0.2222222222222222
sklearn precision (Avg of None):  0.2222222222222222
 
tf.keras precision: 0.22222222

注意:请记住,使用sklearn,您有直接预测标签的模型,并且precision_score是一种独立的方法。因此,它可以直接对预测和实际的标签列表进行操作。然而,tf.keras.Precision()这是一个必须应用于二进制或多类密集输出的度量。它将无法直接使用标签。您必须为每个样本提供一个 n 长度的数组,其中 n 是类/输出密集节点的数量。

希望这可以阐明两者在各种情况下的不同之处。请在sklearn 文档tf.keras 文档中找到更多详细信息。


你的第二个问题——

根据 sklearn 文档,

zero_division - “warn”, 0 or 1, default=”warn”
#Sets the value to return when there is a zero division. If set to “warn”, #this acts as 0, but warnings are also raised.

这是一个异常处理标志。在计算分数的过程中,如果有时间遇到 a divide by zero,它会认为它等于 0 并发出警告。否则,如果明确设置为 1,则将其设置为 1。


推荐阅读