python - Keras Categorical_crossentropy 损失实现
问题描述
我正在尝试重新实现 Keras 的分类交叉熵损失,以便我可以自定义它。我得到了以下
def CustomCrossEntropy(output, target, axis=-1):
target = ops.convert_to_tensor_v2_with_dispatch(target)
output = ops.convert_to_tensor_v2_with_dispatch(output)
target.shape.assert_is_compatible_with(output.shape)
# scale preds so that the class probas of each sample sum to 1
output = output / math_ops.reduce_sum(output, axis, True)
# Compute cross entropy from probabilities.
epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
return -math_ops.reduce_sum(target * math_ops.log(output), axis)
它产生的结果与让我困惑的内部函数不同,因为我到目前为止只是从github复制了代码。我在这里想念什么?
证明:
y_true = [[0., 1., 0.], [0., 0., 1.]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
customLoss = CustomCrossEntropy(y_true, y_pred)
assert loss.shape == (2,)
print(loss)
print(customLoss)
>>tf.Tensor([0.05129331 2.3025851 ], shape=(2,), dtype=float32)
>>tf.Tensor([ 0.8059049 14.506287 ], shape=(2,), dtype=float32)
解决方案
你已经在你的定义中反转了函数的参数CustomCrossEntropy
,如果你仔细检查 GitHub 中的源代码,你会发现第一个参数是target
,第二个是output
. 如果您将它们切换回来,您将获得与预期相同的结果。
import tensorflow as tf
from tensorflow.keras.backend import categorical_crossentropy as CustomCrossEntropy
y_true = [[0., 1., 0.], [0., 0., 1.]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
y_true = tf.convert_to_tensor(y_true)
y_pred = tf.convert_to_tensor(y_pred)
loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
print(loss)
# tf.Tensor([0.05129331 2.3025851 ], shape=(2,), dtype=float32)
loss = CustomCrossEntropy(y_true, y_pred)
print(loss)
# tf.Tensor([0.05129331 2.3025851 ], shape=(2,), dtype=float32)
loss = CustomCrossEntropy(y_pred, y_true)
print(loss)
# tf.Tensor([ 0.8059049 14.506287 ], shape=(2,), dtype=float32)
推荐阅读
- azure - 允许来自 Azure K8S pod 的 AWS RDS 连接
- javascript - 传递 PHP 下一页变量
- expectations - Serverspec 支持期望还是我必须使用 should?
- artifactory - 上传到工件后,是否可以向多个文件添加属性?
- swift - 相同的协议对象,但不同的功能
- c# - 如何让 NLog 输出出现在 Azure 函数的流式日志中?
- swift - 基于自动换行的 Swift 拆分子字符串
- java - 使用 filter() 逐个过滤掉元素
- java - Java 可以通过哪些方式在没有 `throw` 语句的情况下抛出异常?
- c++ - setContextProperty() 在这种情况下如何失败?