tensorflow - 在 tf.keras.metrics.Recall 中使用 thresholds 参数
问题描述
我想知道当指定多个阈值时如何计算召回率。以下是来自https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall的描述片段
thresholds:(可选)一个浮点值,或由 [0, 1] 范围内的浮点阈值组成的 Python 列表/元组。阈值会与预测值进行比较,以确定预测的真值(即,高于阈值为真,低于阈值为假)。每个阈值都会生成一个度量值。如果 thresholds 和 top_k 均未设置,则默认使用 thresholds=0.5 计算召回率。
我正在尝试传递一个包含 3 个阈值的列表,根据上述描述,我预计会生成 3 个召回率(即每个阈值对应一个召回率),但实际并非如此,只生成了 1 个召回率指标。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding,Flatten,Dense
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Bidirectional
from tensorflow.keras.layers import Dropout
from tensorflow.keras import layers
# Build the classifier from the question: an Embedding layer initialised from
# `embedding_matrix`, an LSTM, dropout, and a 9-unit sigmoid output layer.
model = Sequential()
model.add(
    Embedding(
        len(tokens) + 1,
        embedding_dim,
        input_length=MAX_TEXT_LEN,
        weights=[embedding_matrix],
    )
)
model.add(LSTM(128))
model.add(Dropout(0.5))
model.add(Dense(9, activation='sigmoid'))

# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

# The metrics list must be closed with `]` before the call's closing `)`.
model.compile(
    loss='binary_crossentropy',
    optimizer=opt,
    metrics=[tf.keras.metrics.Recall(thresholds=[0.2, 0.4, 0.8])],
)
解决方案
更新
为了能够在训练期间查看每个阈值的度量值,可以编写一个自定义回调,该回调将在每个 epoch 结束时记录每个阈值的值,例如
class CustCallback(callbacks.Callback):
    """Print the validation recall found in the epoch logs after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Report ``val_recall_mult_thr`` for the epoch that just finished.

        Args:
            epoch: Epoch index as passed in by the training loop.
            logs: Mapping of metric names to values; may be ``None``.
        """
        # `logs` defaults to None, so fall back to an empty mapping instead of
        # raising a TypeError when subscripting it.
        logs = logs or {}
        print(
            "The validation recall for epoch {} is {} ".format(
                epoch, logs.get('val_recall_mult_thr')
            )
        )
# Train for three epochs with (X_test, y_test) as validation data; CustCallback
# prints the validation recall at the end of each epoch.
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    callbacks=[CustCallback()],
    batch_size=64,
    epochs=3,
    verbose=2,
)
我也对文档中“为每个阈值生成一个度量值”的说法感到困惑,训练输出中看到的并非如此。训练输出中显示的度量值实际上是您在列表中指定的所有阈值对应度量的算术平均值。下面是一个二分类示例:对于真阳性(True Positives)和召回率(Recall),分别生成一个使用 0.5 决策阈值的度量,以及一个使用所提供阈值列表的度量。
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import (
callbacks,
initializers,
layers,
metrics,
optimizers,
)
from sklearn import __version__ as sk_version
from tensorflow import __version__ as tf_version
# Report the library versions this example was run with.
for lib_name, lib_version in (("sklearn", sk_version), ("tensorflow", tf_version)):
    print(f"The {lib_name} version is {lib_version}.")
# The sklearn version is 1.0.2.
# The tensorflow version is 2.7.0.
# Generate a synthetic two-class dataset with 80/20 class weights.
n_total = 10**6
X1, Y1 = make_classification(
    n_samples=n_total,
    n_features=10,
    n_informative=4,
    n_redundant=0,
    n_classes=2,
    n_clusters_per_class=1,
    weights=[0.8, 0.2],
    random_state=42,
)

# Reserve one third of the samples (integer count) for validation.
X_train, X_test, y_train, y_test = train_test_split(
    X1,
    Y1,
    test_size=n_total // 3,
    random_state=42,
)
# Two-layer network: a 5-unit ReLU hidden layer followed by a single sigmoid
# output unit.
model = keras.Sequential()
ki = initializers.RandomNormal(mean=0.0, stddev=0.05, seed=123)

hidden_layer = layers.Dense(
    5,
    kernel_initializer=ki,
    bias_initializer="zeros",
    input_shape=(X_train.shape[1],),
    activation="relu",
)
model.add(hidden_layer)

output_layer = layers.Dense(1, activation="sigmoid")
model.add(output_layer)
opt = optimizers.Adam(learning_rate=0.01)

# Thresholds 0.05, 0.10, ..., 0.95 used by the multi-threshold metrics.
thresh_list = list(np.arange(0.05, 1, 0.05))

# Each thresholded metric is tracked twice: once at the single 0.5 threshold
# and once with the whole threshold list, so the two can be compared.
tracked_metrics = [
    metrics.TruePositives(thresholds=0.5, name="tp_0_5"),
    metrics.TruePositives(thresholds=thresh_list, name='tp_mult_thr'),
    metrics.FalsePositives(name="fp"),
    metrics.TrueNegatives(name="tn"),
    metrics.FalseNegatives(name="fn"),
    metrics.Recall(thresholds=0.5, name='recall_0_5'),
    metrics.Recall(thresholds=thresh_list, name="recall_mult_thr"),
]

model.compile(
    loss='binary_crossentropy',
    optimizer=opt,
    metrics=tracked_metrics,
)
# Train for three epochs with (X_test, y_test) as validation data. The
# commented lines below are the output recorded from the original run.
model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    batch_size=64,
    epochs=3,
    verbose=2,
)
# Epoch 1/3
# 10417/10417 - 23s - loss: 0.0857 - tp_0_5: 123905.0000 - tp_mult_thr: 121617.1016 -
# fp: 5823.0000 - tn: 525617.0000 - fn: 11322.0000 - recall_0_5: 0.9163 -
# recall_mult_thr: 0.8994 - val_loss: 0.0809 - val_tp_0_5: 61862.0000 -
# val_tp_mult_thr: 60766.5781 - val_fp: 2210.0000 - val_tn: 263435.0000 -
# val_fn: 5826.0000 - val_recall_0_5: 0.9139 - val_recall_mult_thr: 0.8977 -
# 23s/epoch - 2ms/step
# Epoch 2/3
# 10417/10417 - 24s - loss: 0.0806 - tp_0_5: 124853.0000 - tp_mult_thr: 122652.4766 -
# fp: 5573.0000 - tn: 525867.0000 - fn: 10374.0000 - recall_0_5: 0.9233 -
# recall_mult_thr: 0.9070 - val_loss: 0.0789 - val_tp_0_5: 62789.0000 -
# val_tp_mult_thr: 61644.4219 - val_fp: 2889.0000 - val_tn: 262756.0000 -
# val_fn: 4899.0000 - val_recall_0_5: 0.9276 - val_recall_mult_thr: 0.9107 -
# 24s/epoch - 2ms/step
# Epoch 3/3
# 10417/10417 - 25s - loss: 0.0777 - tp_0_5: 125268.0000 - tp_mult_thr: 123142.8984 -
# fp: 5556.0000 - tn: 525884.0000 - fn: 9959.0000 - recall_0_5: 0.9264 -
# recall_mult_thr: 0.9106 - val_loss: 0.0781 - val_tp_0_5: 61261.0000 -
# val_tp_mult_thr: 60321.4727 - val_fp: 1638.0000 - val_tn: 264007.0000 -
# val_fn: 6427.0000 - val_recall_0_5: 0.9050 - val_recall_mult_thr: 0.8912 -
# 25s/epoch - 2ms/step
0.5 阈值 (.9050) 的验证集 Recall 符合预期:
from sklearn.metrics import recall_score
# Recompute recall with scikit-learn at every threshold so the values can be
# compared with the Keras metrics logged during training.
pred_probs = model.predict(X_test)
for t in thresh_list:
    # Scores at or above the threshold are labelled positive.
    hard_labels = (pred_probs >= t).astype('int8')
    recall_at_t = recall_score(y_test, hard_labels)
    print(f"threshold: {t:.2f}, recall: {recall_at_t}")
# threshold: 0.05, recall: 0.9748847653941615
# threshold: 0.10, recall: 0.9671876846708427
# threshold: 0.15, recall: 0.9597565299609975
# threshold: 0.20, recall: 0.9526947169365323
# threshold: 0.25, recall: 0.9446726155300792
# threshold: 0.30, recall: 0.9371971398179885
# threshold: 0.35, recall: 0.9298546271126344
# threshold: 0.40, recall: 0.9214779576882165
# threshold: 0.45, recall: 0.9133967616121026
# threshold: 0.50, recall: 0.905049639522515 <---
# threshold: 0.55, recall: 0.8952990190284836
# threshold: 0.60, recall: 0.8855483985344522
# threshold: 0.65, recall: 0.8741283536225033
# threshold: 0.70, recall: 0.8619105306701336
# threshold: 0.75, recall: 0.8483039829807352
# threshold: 0.80, recall: 0.8315654177993145
# threshold: 0.85, recall: 0.8107640940787141
# threshold: 0.90, recall: 0.7819258952842454
# threshold: 0.95, recall: 0.7366002836544143
但是多个阈值的验证集 Recall (.8912) 是所有阈值的召回平均值:
# Arithmetic mean of the per-threshold recalls; compare with the
# val_recall_mult_thr value logged for the final epoch.
per_threshold_recalls = [
    recall_score(y_test, (pred_probs >= t).astype('int8'))
    for t in thresh_list
]
np.mean(per_threshold_recalls)
# 0.8911693902052141
这同样适用于真阳性:
from sklearn.metrics import confusion_matrix
# Count true positives at every threshold; entry [1, 1] of the confusion
# matrix is (true label 1, predicted label 1).
for t in thresh_list:
    hard_labels = (pred_probs >= t).astype('int8')
    true_positives = confusion_matrix(y_test, hard_labels)[1, 1]
    print(f"threshold: {t:.2f}, TPs: {true_positives}")
# threshold: 0.05, TPs: 65988
# threshold: 0.10, TPs: 65467
# threshold: 0.15, TPs: 64964
# threshold: 0.20, TPs: 64486
# threshold: 0.25, TPs: 63943
# threshold: 0.30, TPs: 63437
# threshold: 0.35, TPs: 62940
# threshold: 0.40, TPs: 62373
# threshold: 0.45, TPs: 61826
# threshold: 0.50, TPs: 61261 <---
# threshold: 0.55, TPs: 60601
# threshold: 0.60, TPs: 59941
# threshold: 0.65, TPs: 59168
# threshold: 0.70, TPs: 58341
# threshold: 0.75, TPs: 57420
# threshold: 0.80, TPs: 56287
# threshold: 0.85, TPs: 54879
# threshold: 0.90, TPs: 52927
# threshold: 0.95, TPs: 49859
和:
# Collect the true-positive count at each threshold, then average the counts;
# compare with the val_tp_mult_thr value logged for the final epoch.
tp_list = [
    confusion_matrix(y_test, (pred_probs >= t).astype('int8'))[1, 1]
    for t in thresh_list
]
print(f"Avg TPs across all thresholds: {np.mean(tp_list)}")
# Avg TPs across all thresholds: 60321.47368421053
推荐阅读
- html - 如何在不同的浏览器和不同的屏幕尺寸下调整网站的缩放比例?
- node.js - nodejs 连接到 AWS RDS postgres 数据库时出错
- javascript - 带有 window.location.reload 的递归异步等待回调
- java - 删除文件后可用空间相同
- amazon-web-services - 是否存在任何 api 调用来验证访问令牌?
- ios - 如何在 iOS WKWebView 应用程序中下载 .vcf 文件以将其直接保存到 iPhone 的联系人中?
- laravel - 为什么在 Laravel 5.8 Observer 中不起作用
- html - 表内表内表丢失表体格式
- php - 如何更改实时 Laravel 站点?
- c# - 如何更具体地了解通用约束?