首页 > 解决方案 > 在 tf.keras.metrics.Recall 中使用 thresholds 参数

问题描述

我想知道当指定多个阈值时,召回率是如何计算的。以下是摘自 https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall 的描述片段:

thresholds:(可选)一个浮点值,或由 [0, 1] 区间内的浮点阈值组成的 Python 列表/元组。阈值会与预测值进行比较,以确定预测的真假(即高于阈值为真,低于阈值为假)。每个阈值都会生成一个度量值。如果 thresholds 和 top_k 均未设置,则默认使用 thresholds=0.5 计算召回率。

我尝试传入一个包含 3 个阈值的列表。根据上述描述,我预计会得到 3 个召回率(即每个阈值对应一个召回率),但实际并非如此:只生成了 1 个召回率指标。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding,Flatten,Dense
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Bidirectional
from tensorflow.keras.layers import Dropout
from tensorflow.keras import layers
# Build the question's model: Embedding -> LSTM(128) -> Dropout(0.5) -> Dense(9, sigmoid).
# NOTE: tokens, embedding_dim, MAX_TEXT_LEN and embedding_matrix are not defined
# in this snippet; they must exist before this code runs.
model = Sequential()
model.add(
    Embedding(
        len(tokens) + 1,
        embedding_dim,
        input_length=MAX_TEXT_LEN,
        weights=[embedding_matrix],
    )
)
model.add(LSTM(128))
model.add(Dropout(0.5))
model.add(Dense(9, activation='sigmoid'))
# `learning_rate` is the supported keyword; `lr` is only a deprecated alias.
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(
    loss='binary_crossentropy',
    optimizer=opt,
    # The original line never closed the metrics list, which is a SyntaxError.
    metrics=[tf.keras.metrics.Recall(thresholds=[0.2, 0.4, 0.8])],
)

标签: tensorflow, keras, precision-recall

解决方案


更新

为了能够在训练期间查看每个阈值的度量值,可以编写一个自定义回调,在每个 epoch 结束时记录相应的值(其中 callbacks 来自下文示例中的 `from tensorflow.keras import callbacks`),例如:

class CustCallback(callbacks.Callback):
    """Print the metric logged as 'val_recall_mult_thr' at the end of each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print the multi-threshold validation recall for the given epoch.

        Args:
            epoch: Epoch index supplied by the training loop.
            logs: Mapping of metric names to values for this epoch, or None.
        """
        # `logs` defaults to None, and 'val_recall_mult_thr' is only present
        # when a metric with that name was evaluated on validation data.
        # Fall back to None instead of raising TypeError / KeyError.
        recall = (logs or {}).get('val_recall_mult_thr')
        print(
            "The validation recall for epoch {} is {} ".format(epoch, recall)
        )

# Train for 3 epochs with validation data supplied, so the callback can read
# the 'val_*' entries from `logs` after each epoch.
history = model.fit(
    X_train,
    y_train,
    batch_size=64,
    epochs=3,
    verbose=2,
    validation_data=(X_test, y_test),
    callbacks=[CustCallback()],
)

我也对文档中“为每个阈值生成一个度量值”这一表述感到困惑,因为实际情况并非如此:生成的度量值实际上是列表中所有阈值各自对应度量值的算术平均值。下面是一个二分类示例:对于真阳性数(True Positives)和召回率(Recall),分别生成一个决策阈值为 0.5 的度量,以及一个使用所提供阈值列表的度量。

import numpy as np

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from tensorflow import keras
from tensorflow.keras import (
    callbacks,
    initializers,
    layers,
    metrics,
    optimizers,
)

from sklearn import __version__ as sk_version
from tensorflow import __version__ as tf_version
# Report the library versions used for this example.
# Output when the example was written:
#   The sklearn version is 1.0.2.
#   The tensorflow version is 2.7.0.
print(
    f"The sklearn version is {sk_version}.",
    f"The tensorflow version is {tf_version}.",
    sep="\n",
)

# One million synthetic samples with 10 features (4 informative, none
# redundant) and two classes weighted 80/20; the seed is fixed.
X1, Y1 = make_classification(
    n_samples=1_000_000,
    n_features=10,
    n_informative=4,
    n_redundant=0,
    n_classes=2,
    n_clusters_per_class=1,
    weights=[0.8, 0.2],
    random_state=42,
)

# Hold out one third of the samples (333,333 rows) as the validation split.
X_train, X_test, y_train, y_test = train_test_split(
    X1,
    Y1,
    test_size=1_000_000 // 3,
    random_state=42,
)

# Small binary classifier: a 5-unit ReLU hidden layer followed by a single
# sigmoid output unit. Only the hidden layer's kernel uses the seeded
# RandomNormal initializer.
ki = initializers.RandomNormal(mean=0.0, stddev=0.05, seed=123)
model = keras.Sequential(
    [
        layers.Dense(
            5,
            activation="relu",
            input_shape=(X_train.shape[1],),
            kernel_initializer=ki,
            bias_initializer="zeros",
        ),
        layers.Dense(1, activation="sigmoid"),
    ]
)

# Decision thresholds 0.05, 0.10, ..., 0.95 (19 values).
thresh_list = list(np.arange(0.05, 1, 0.05))
opt = optimizers.Adam(learning_rate=0.01)

model.compile(
    optimizer=opt,
    loss='binary_crossentropy',
    metrics=[
        # Each single-threshold metric is paired with a multi-threshold
        # variant so the two logged values can be compared.
        metrics.TruePositives(name="tp_0_5", thresholds=0.5),
        metrics.TruePositives(name='tp_mult_thr', thresholds=thresh_list),
        # No thresholds are passed here, so the Keras default applies.
        metrics.FalsePositives(name="fp"),
        metrics.TrueNegatives(name="tn"),
        metrics.FalseNegatives(name="fn"),
        metrics.Recall(name='recall_0_5', thresholds=0.5),
        metrics.Recall(name="recall_mult_thr", thresholds=thresh_list),
    ],
)

# Fit for 3 epochs, passing the held-out split as validation data.
model.fit(
    X_train,
    y_train,
    batch_size=64,
    epochs=3,
    verbose=2,
    validation_data=(X_test, y_test),
)

# Epoch 1/3
# 10417/10417 - 23s - loss: 0.0857 - tp_0_5: 123905.0000 - tp_mult_thr: 121617.1016 - 
# fp: 5823.0000 - tn: 525617.0000 - fn: 11322.0000 - recall_0_5: 0.9163 - 
# recall_mult_thr: 0.8994 - val_loss: 0.0809 - val_tp_0_5: 61862.0000 - 
# val_tp_mult_thr: 60766.5781 - val_fp: 2210.0000 - val_tn: 263435.0000 - 
# val_fn: 5826.0000 - val_recall_0_5: 0.9139 - val_recall_mult_thr: 0.8977 - 
# 23s/epoch - 2ms/step
# Epoch 2/3
# 10417/10417 - 24s - loss: 0.0806 - tp_0_5: 124853.0000 - tp_mult_thr: 122652.4766 - 
# fp: 5573.0000 - tn: 525867.0000 - fn: 10374.0000 - recall_0_5: 0.9233 - 
# recall_mult_thr: 0.9070 - val_loss: 0.0789 - val_tp_0_5: 62789.0000 - 
# val_tp_mult_thr: 61644.4219 - val_fp: 2889.0000 - val_tn: 262756.0000 - 
# val_fn: 4899.0000 - val_recall_0_5: 0.9276 - val_recall_mult_thr: 0.9107 - 
# 24s/epoch - 2ms/step
# Epoch 3/3
# 10417/10417 - 25s - loss: 0.0777 - tp_0_5: 125268.0000 - tp_mult_thr: 123142.8984 - 
# fp: 5556.0000 - tn: 525884.0000 - fn: 9959.0000 - recall_0_5: 0.9264 - 
# recall_mult_thr: 0.9106 - val_loss: 0.0781 - val_tp_0_5: 61261.0000 - 
# val_tp_mult_thr: 60321.4727 - val_fp: 1638.0000 - val_tn: 264007.0000 - 
# val_fn: 6427.0000 - val_recall_0_5: 0.9050 - val_recall_mult_thr: 0.8912 - 
# 25s/epoch - 2ms/step

阈值为 0.5 时的验证集召回率(0.9050)符合预期:

from sklearn.metrics import recall_score

# Predicted probabilities for the validation split.
pred_probs = model.predict(X_test)

# Recompute recall with scikit-learn at every threshold for comparison with
# the values logged by Keras.
for t in thresh_list:
    y_hat = (pred_probs >= t).astype('int8')
    rec = recall_score(y_test, y_hat)
    print(
        f"threshold: {t:.2f}, "
        f"recall: {rec}"
    )
# threshold: 0.05, recall: 0.9748847653941615
# threshold: 0.10, recall: 0.9671876846708427
# threshold: 0.15, recall: 0.9597565299609975
# threshold: 0.20, recall: 0.9526947169365323
# threshold: 0.25, recall: 0.9446726155300792
# threshold: 0.30, recall: 0.9371971398179885
# threshold: 0.35, recall: 0.9298546271126344
# threshold: 0.40, recall: 0.9214779576882165
# threshold: 0.45, recall: 0.9133967616121026
# threshold: 0.50, recall: 0.905049639522515  <---
# threshold: 0.55, recall: 0.8952990190284836
# threshold: 0.60, recall: 0.8855483985344522
# threshold: 0.65, recall: 0.8741283536225033
# threshold: 0.70, recall: 0.8619105306701336
# threshold: 0.75, recall: 0.8483039829807352
# threshold: 0.80, recall: 0.8315654177993145
# threshold: 0.85, recall: 0.8107640940787141
# threshold: 0.90, recall: 0.7819258952842454
# threshold: 0.95, recall: 0.7366002836544143

但使用多个阈值时的验证集召回率(0.8912)是所有阈值下召回率的平均值:

# Arithmetic mean of the per-threshold validation recalls.
np.mean(
    [
        recall_score(y_test, (pred_probs >= thr).astype('int8'))
        for thr in thresh_list
    ]
)
# 0.8911693902052141

这同样适用于真阳性:

from sklearn.metrics import confusion_matrix

# Count true positives at every threshold with scikit-learn.
for t in thresh_list:
    y_hat = (pred_probs >= t).astype('int8')
    # Entry [1, 1] of the confusion matrix: true label 1, predicted label 1.
    tp = confusion_matrix(y_test, y_hat)[1, 1]
    print(
        f"threshold: {t:.2f}, "
        f"TPs: {tp}"
    )
# threshold: 0.05, TPs: 65988
# threshold: 0.10, TPs: 65467
# threshold: 0.15, TPs: 64964
# threshold: 0.20, TPs: 64486
# threshold: 0.25, TPs: 63943
# threshold: 0.30, TPs: 63437
# threshold: 0.35, TPs: 62940
# threshold: 0.40, TPs: 62373
# threshold: 0.45, TPs: 61826
# threshold: 0.50, TPs: 61261 <---
# threshold: 0.55, TPs: 60601
# threshold: 0.60, TPs: 59941
# threshold: 0.65, TPs: 59168
# threshold: 0.70, TPs: 58341
# threshold: 0.75, TPs: 57420
# threshold: 0.80, TPs: 56287
# threshold: 0.85, TPs: 54879
# threshold: 0.90, TPs: 52927
# threshold: 0.95, TPs: 49859

并且:

# True-positive count at each threshold, then the mean across thresholds.
tp_list = [
    confusion_matrix(y_test, (pred_probs >= t).astype('int8'))[1, 1]
    for t in thresh_list
]

print(f"Avg TPs across all thresholds: {np.mean(tp_list)}")
# Avg TPs across all thresholds: 60321.47368421053

推荐阅读