首页 > 解决方案 > 多标签分类损失函数

问题描述

我在很多地方看到,对于使用神经网络的多标签分类,一个有用的损失函数是每个输出节点的二进制交叉熵。

在 TensorFlow 中,它看起来像这样:

cost = tf.nn.sigmoid_cross_entropy_with_logits()

这给出了一个具有与我们拥有的输出节点一样多的值的数组。

我的问题是,这个成本函数是否应该在输出节点的数量上取平均值?在 Tensorflow 中看起来像:

cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits())

还是每个损失都单独处理?

谢谢

标签: pythontensorflowneural-networkclassificationmultilabel-classification

解决方案


对于N多标签分类中的标签,是否对每个类的损失求和,或者是否使用计算平均损失并不重要tf.reduce_mean:梯度将指向同一方向。

但是,如果将总和除以N(这就是平均化的本质),这将影响一天结束时的学习率。如果你不确定多标签分类任务中会有多少标签,它可能更容易使用tf.reduce_mean,因为你不必重新调整这个损失组件与损失的其他组件相比的权重,并且你不必在N标签更改数量中调整学习率。


推荐阅读