python - 精度类型
问题描述
精度是通过使用 keras 库获得的:
model.compile(optimizer='sgd',
loss='mse',
metrics=[tf.keras.metrics.Precision()])
sklearn计算的精度和keras计算的精度是什么类型的?
precision_score(y_true, y_pred, average=???)
- 宏
- 微
- 加权
- 没有任何
如下所示将 zero_division 设置为 1 时会发生什么?:
precision_score(y_true, y_pred, average=None, zero_division=1)
解决方案
TLDR;默认binary
为二分类和micro
多分类。其他平均类型,例如None
和macro
也可以通过如下所述的微小修改来实现。
tf.keras.Precision()
这应该让您清楚地了解和之间的区别sklearn.metrics.precision_score()
。让我们比较不同的场景。
场景一:二分类
对于二进制分类,您的 y_true 和 y_pred 分别为 0,1 和 0-1。两者的实现都非常简单。
Sklearn 文档:仅报告 pos_label 指定的类的结果。这仅适用于目标 (y_{true,pred}) 是二进制的。
#Binary classification
from sklearn.metrics import precision_score
import tensorflow as tf
y_true = [0,1,1,1]
y_pred = [1,0,1,1]
print('sklearn precision: ',precision_score(y_true, y_pred, average='binary'))
#Only report results for the class specified by pos_label.
#This is applicable only if targets (y_{true,pred}) are binary.
m = tf.keras.metrics.Precision()
m.update_state(y_true, y_pred)
print('tf.keras precision:',m.result().numpy())
sklearn precision: 0.6666666666666666
tf.keras precision: 0.6666667
场景二:多类分类(全局精度)
在这里,您正在使用多类标签,但您不必担心每个单独类的精度如何。您只需要一组全局 TP 和 FP 来计算总精度分数。in sklearn
this 由参数设置micro
,而 in tf.keras
this 是默认设置Precision()
Sklearn 文档:通过计算总的真阳性、假阴性和假阳性来全局计算指标。
#Multi-class classification (global precision)
#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]
print('sklearn precision: ',precision_score(y_true, y_pred, average='micro'))
#Calculate metrics globally by counting the total true positives, false negatives and false positives.
m.reset_states()
m = tf.keras.metrics.Precision()
m.update_state(y_true, y_pred)
print('tf.keras precision:',m.result().numpy())
sklearn precision: 0.3333333333333333
tf.keras precision: 0.33333334
场景 3:多类分类(每个标签的二进制精度)
如果您想知道每个单独类的精度,您会对这种情况感兴趣。sklearn
这是通过将average
参数设置为 来完成的None
,而在tf.keras
您必须分别使用 为每个单独的类实例化对象class_id
。
Sklearn 文档:如果没有,则返回每个班级的分数。
#Multi-class classification (binary precision for each label)
#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]
print('sklearn precision: ',precision_score(y_true, y_pred, average=None))
#If None, the scores for each class are returned.
#For class 0
m0 = tf.keras.metrics.Precision(class_id=0)
m0.update_state(y_true, y_pred)
#For class 1
m1 = tf.keras.metrics.Precision(class_id=1)
m1.update_state(y_true, y_pred)
#For class 2
m2 = tf.keras.metrics.Precision(class_id=2)
m2.update_state(y_true, y_pred)
mm = [m0.result().numpy(), m1.result().numpy(), m2.result().numpy()]
print('tf.keras precision:',mm)
sklearn precision: [0.66666667 0. 0. ]
tf.keras precision: [0.6666667, 0.0, 0.0]
场景 4:多类分类(单个二进制分数的平均值)
一旦你计算了每个类的单独精度,你可能想要取平均分(或加权平均值)。在中,通过将参数设置为来sklearn
获取单个分数的简单平均值。您可以通过取上述场景中计算的各个精度的平均值来获得相同的结果。average
macro
tf.keras
Sklearn 文档:计算每个标签的指标,并找到它们的未加权平均值。
#Multi-class classification (Average of individual binary scores)
#3 classes, 6 samples
y_true = [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0],[0,0,1]]
y_pred = [[1,0,0],[0,0,1],[0,1,0],[1,0,0],[1,0,0],[0,1,0]]
print('sklearn precision (Macro): ',precision_score(y_true, y_pred, average='macro'))
print('sklearn precision (Avg of None):' ,np.average(precision_score(y_true, y_pred, average=None)))
print(' ')
print('tf.keras precision:',np.average(mm)) #mm is list of individual precision scores
sklearn precision (Macro): 0.2222222222222222
sklearn precision (Avg of None): 0.2222222222222222
tf.keras precision: 0.22222222
注意:请记住,使用sklearn
,您有直接预测标签的模型,并且precision_score
是一种独立的方法。因此,它可以直接对预测和实际的标签列表进行操作。然而,tf.keras.Precision()
这是一个必须应用于二进制或多类密集输出的度量。它将无法直接使用标签。您必须为每个样本提供一个 n 长度的数组,其中 n 是类/输出密集节点的数量。
希望这可以阐明两者在各种情况下的不同之处。请在sklearn 文档和tf.keras 文档中找到更多详细信息。
你的第二个问题——
根据 sklearn 文档,
zero_division - “warn”, 0 or 1, default=”warn”
#Sets the value to return when there is a zero division. If set to “warn”, #this acts as 0, but warnings are also raised.
这是一个异常处理标志。在计算分数的过程中,如果有时间遇到 a divide by zero
,它会认为它等于 0 并发出警告。否则,如果明确设置为 1,则将其设置为 1。
推荐阅读
- javascript - 为什么 Jquery 找不到我的 json_encoded() 嵌套数组?
- html - 使用 css 文本部分溢出的图像
- c# - 如何使全屏切换按钮开/关
- angular - 当我将引导程序放入styles.css Angular时如何删除警告?
- reactjs - 在对列进行排序后检查或选择行时,HTML 表会返回
- xamarin.forms - 如何在 xamarin 表单中为 NavigationPage 类或导航后退按钮设置或实现automationId 属性?
- bash - 我有一个大的多行文本文件,将行内容读入 bash 变量,直到出现特定字符(终止符)
- webpack - 使 webpack-dev-server 重用现有打开的选项卡
- python - 比较python中的两个excel文件
- amazon-s3 - 无法在 AWS S3 之上从 presto 创建架构/表