首页 > 解决方案 > Tensorflow Github 源代码中的 Softmax 交叉熵实现

问题描述

我正在尝试在 python 中实现 Softmax 交叉熵损失。因此,我在 GitHub Tensorflow 存储库中查看了 Softmax 交叉熵损失的实现。我试图理解它,但我遇到了三个函数的循环,我不明白函数中的哪一行代码正在计算损失?

该函数softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None)返回该函数 softmax_cross_entropy_with_logits_v2_helper(labels=labels, logits=logits, axis=axis, name=name),该函数又返回softmax_cross_entropy_with_logits(precise_logits, labels, name=name)

现在函数softmax_cross_entropy_with_logits(precise_logits, labels, name=name)返回函数softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None)

这让我陷入了一个函数循环,而没有明确知道计算costfor Softmax 函数的代码在哪里。谁能指出 Softmax 交叉熵的代码在 Tensorflow GitHub 存储库中实现的位置?

我引用的 GitHub 存储库的链接在这里。它包含上述三个函数的定义。

如果代码cost需要很多难以理解的功能,你能解释一下代码行吗?谢谢。

标签: pythontensorflowbazelsoftmaxcross-entropy

解决方案


当您跟踪此函数的调用堆栈时,您最终会发现

cost, unused_backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
      precise_logits, labels, name=name)

每当您看到对gen_模块的引用时,这意味着它是 C++ 代码上自动生成的 python 包装器 - 这就是为什么您无法通过简单地查找函数并跟踪调用堆栈来找到它。

C++ 源代码可以在这里找到。

这个答案gen_nn_ops很好地描述了如何创建。


推荐阅读