python - 如何在 Tensorflow 2.0 中获取其他指标(不仅是准确性)?
问题描述
我是 Tensorflow 领域的新手,我正在研究 mnist 数据集分类的简单示例。我想知道除了准确性和损失(并可能显示它们)之外,我如何获得其他指标(例如精度、召回率等)。这是我的代码:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard
import os
#load mnist dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#create and compile the model
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.summary()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#model checkpoint (only if there is an improvement)
checkpoint_path = "logs/weights-improvement-{epoch:02d}-{accuracy:.2f}.hdf5"
cp_callback = ModelCheckpoint(checkpoint_path, monitor='accuracy',save_best_only=True,verbose=1, mode='max')
#Tensorboard
NAME = "tensorboard_{}".format(int(time.time())) #name of the model with timestamp
tensorboard = TensorBoard(log_dir="logs/{}".format(NAME))
#train the model
model.fit(x_train, y_train, callbacks = [cp_callback, tensorboard], epochs=5)
#evaluate the model
model.evaluate(x_test, y_test, verbose=2)
由于我只得到准确度和损失,我怎样才能得到其他指标?提前谢谢你,如果这是一个简单的问题,或者如果已经在某个地方得到回答,我很抱歉。
解决方案
从 TensorFlow 2.X 开始,precision
都recall
可以作为内置指标使用。
因此,您不需要手动实现它们。除此之外,它们之前在 Keras 2.X 版本中被删除,因为它们具有误导性——因为它们是以批量方式计算的,精度和召回率的全局(真实)值实际上是不同的。
你可以在这里看看:https ://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
现在他们有一个内置的累加器,可以确保正确计算这些指标。
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy',tf.keras.metrics.Precision(),tf.keras.metrics.Recall()])
推荐阅读
- tabs - C++、MFC MDI、激活特定选项卡
- php - 在linux上使用php检查服务状态?
- css - 如何使用 Laravel Mix 将版本号添加到路径中?
- tesseract - 为什么要使用 RecursiveParserWrapper 而不是 Parser 来从图像中提取文本?
- android - 无法使用 firebase-config:16.0.0 和 firebase-core:16.0.1 构建
- wordpress - 如何将类别的默认值添加到 wordpress 中的自定义帖子?
- html - 从输入到其他组件的Angular6数据
- css - 在 CSS 中使用标题中的特殊格式缩进
- css - 为 formGroup 禁用 ng-invalid
- apache-spark - “SHOW TABLES LIKE '*sub_string*'” 不适用于 HIVECONTEXT