首页 > 解决方案 > 如何使用 TensorFlow 2.0 计算分类交叉熵?

问题描述

我试图了解CategoricalCrossentropy()TensorFlow 2.0 中的损失函数。当我使用

tf.keras.metrics.CategoricalCrossentropy(actual, pred)

我收到以下错误:

ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()

谁能解释如何纠正?

标签: pythonnumpytensorflowtensorflow2.0

解决方案


docs计算交叉熵,使用,

cce = tf.keras.metrics.CategoricalCrossentropy()
# cce.update_state(target, prediction)
cce.update_state([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

cce.result().numpy()
# 1.1769392

OTOH,如果您尝试计算交叉熵损失,请tf.keras.losses.CategoricalCrossentropy改用:

cce = tf.keras.metrics.CategoricalCrossentropy()
cce([[0, 1, 0], [0, 0, 1]], [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]).numpy()
# 1.1769392

推荐阅读