python - 计算 Tensorflow 2 中每个 epoch 后每个类的召回率
问题描述
我正在尝试在使用 Tensorflow 2 的 Keras API 的模型中的每个时期之后计算每个类在二进制和多类(一个热编码)分类场景中的召回率。例如对于二进制分类,我希望能够做类似的事情
import tensorflow as tf
model = tf.keras.Sequential()
model.add(...)
model.add(tf.keras.layers.Dense(1))
model.compile(metrics=[binary_recall(label=0), binary_recall(label=1)], ...)
history = model.fit(...)
plt.plot(history.history['binary_recall_0'])
plt.plot(history.history['binary_recall_1'])
plt.show()
或者在多类场景中,我想做类似的事情
model = tf.keras.Sequential()
model.add(...)
model.add(tf.keras.layers.Dense(3))
model.compile(metrics=[recall(label=0), recall(label=1), recall(label=2)], ...)
history = model.fit(...)
plt.plot(history.history['recall_0'])
plt.plot(history.history['recall_1'])
plt.plot(history.history['recall_2'])
plt.show()
我正在研究一个不平衡数据集的分类器,并希望能够看到我的少数类的召回率在什么时候开始下降。
我在这里https://stackoverflow.com/a/41717938/373655找到了多类分类器中特定类的精度实现。我正在尝试将其调整为我需要的内容,但keras.backend
对我来说仍然很陌生,因此将不胜感激。
我也不清楚是否可以使用 Keras metrics
(因为它们是在每批结束时计算然后取平均值)还是我需要使用 Keras callbacks
(可以在每个 epoch 结束时运行)。在我看来,它不应该对召回产生影响(例如8/10 == (3/5 + 5/5) / 2
),但这就是为什么在 Keras 2 中删除了召回,所以也许我遗漏了一些东西(https://github.com/keras-team/keras/issues /5794 )
编辑 - 部分解决方案(多类分类) @mujjiga 的解决方案适用于二元分类和多类分类,但正如 @P-Gn 指出的那样,tensorflow 2 的召回指标支持开箱即用的多类分类。例如
from tensorflow.keras.metrics import Recall
model = ...
model.compile(loss='categorical_crossentropy', metrics=[
Recall(class_id=0, name='recall_0')
Recall(class_id=1, name='recall_1')
Recall(class_id=2, name='recall_2')
])
history = model.fit(...)
plt.plot(history.history['recall_2'])
plt.plot(history.history['val_recall_2'])
plt.show()
解决方案
在 TF2 中,tf.keras.metrics.Recall
获得了class_id
能够做到这一点的成员。使用 FashionMNIST 的示例:
import tensorflow as tf
(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train[..., None].astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train)
input_shape = x_train.shape[1:]
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=input_shape),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=256, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(units=10, activation='softmax')])
model.compile(loss='categorical_crossentropy', optimizer='Adam',
metrics=[tf.keras.metrics.Recall(class_id=i) for i in range(10)])
model.fit(x_train, y_train, batch_size=128, epochs=50)
在 TF 1.13 中,tf.keras.metric.Recall
没有此class_id
参数,但可以通过子类化添加(有些令人惊讶的是,在 TF2 的 alpha 版本中似乎不可能)。
class Recall(tf.keras.metrics.Recall):
def __init__(self, *, class_id, **kwargs):
super().__init__(**kwargs)
self.class_id= class_id
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = y_true[:, self.class_id]
y_pred = tf.cast(tf.equal(
tf.math.argmax(y_pred, axis=-1), self.class_id), dtype=tf.float32)
return super().update_state(y_true, y_pred, sample_weight)
推荐阅读
- node.js - axios 是否会延迟 http 请求
- jquery - 每次用户标记具有特定类的单选按钮时如何增加分数
- javascript - 无法在标签中显示新行
- microsoft-graph-api - Microsoft Graph API 将新的 ListItem 字段日期时间间隔插入到 8 小时
- xamarin.forms - Xamarin 表单中的轮播视图不响应视图中的图像按钮
- javascript - 将 `Array.prototype.includes` 传递给回调而不将其包装在异常函数中?
- javascript - React 组件正在无限次重新渲染
- python - 用链式掩码替换 numpy 数组元素
- xml - Perl XPath:使用“and”返回两个节点
- python - PDF到Python中的文本在图像文件中返回空结果