python - 如何通过 Tensorflow2.x 中的子类 tf.keras.losses.Loss 类自定义损失
问题描述
当我阅读 Tensorflow 网站上的指南时,我发现了两种自定义损失的方法。第一个是定义一个损失函数,就像:
def basic_loss_function(y_true, y_pred):
return tf.math.reduce_mean(tf.abs(y_true - y_pred))
并且为了简单起见,我们假设batch size也是1,所以和的形状y_true
都是y_pred
(1, c),其中c是类的数量。所以在这个方法中,我们给出两个向量y_true
和y_pred
,并返回一个值(scala)。
然后,第二种方法是子tf.keras.losses.Loss
类化,guide中的代码是:
class WeightedBinaryCrossEntropy(keras.losses.Loss):
"""
Args:
pos_weight: Scalar to affect the positive labels of the loss function.
weight: Scalar to affect the entirety of the loss function.
from_logits: Whether to compute loss from logits or the probability.
reduction: Type of tf.keras.losses.Reduction to apply to loss.
name: Name of the loss function.
"""
def __init__(self, pos_weight, weight, from_logits=False,
reduction=keras.losses.Reduction.AUTO,
name='weighted_binary_crossentropy'):
super().__init__(reduction=reduction, name=name)
self.pos_weight = pos_weight
self.weight = weight
self.from_logits = from_logits
def call(self, y_true, y_pred):
ce = tf.losses.binary_crossentropy(
y_true, y_pred, from_logits=self.from_logits)[:,None]
ce = self.weight * (ce*(1-y_true) + self.pos_weight*ce*(y_true))
return ce
在调用方法中,像往常一样,我们给出了两个向量y_true
和y_pred
,但我注意到它返回ce
,这是一个形状为 (1, c) 的向量!
那么上面的玩具例子有什么问题吗?或者 Tensorflow2.x 背后有什么魔力?
解决方案
除了实现之外,两者之间的主要区别在于损失函数的类型。第一个是 L1 损失(定义的绝对差异的平均值,主要用于类似回归的问题),而第二个是二元交叉熵(用于分类)。它们并不意味着相同损失的不同实现,这在您链接的指南中有所说明。
多标签、多类分类设置中的二元交叉熵为每个类输出一个值,就好像它们彼此独立一样。
编辑:
在第二个损失函数中,reduction
参数控制输出聚合的方式,例如。取元素的总和或对批次求和等。默认情况下,您的代码使用,如果您检查源代码keras.losses.Reduction.AUTO
,则转换为对批次求和。这意味着,最终损失将是一个向量,但还有其他可用的缩减,您可以在docs中查看它们。我相信即使你没有定义减少来获取损失向量中损失元素的总和,TF 优化器也会这样做,以避免反向传播向量的错误。向量上的反向传播会导致“有助于”每个损失元素的权重出现问题。但是,我没有在源代码中检查这一点。:)
推荐阅读
- python - 复制然后粘贴在python终端mac中产生垃圾
- javascript - 从表中获取数据以显示为纯文本
- php - 有没有办法连接 1 行唯一键并在表单中显示其数据?
- angular - 如何在带有 Firebase 的 Angular 7 中使用带有反应形式的图像上传
- java - 使用 java 访问 AWS 根用户帐户
- java - 通过 selenium 在 IE 浏览器中打开一个新窗口,没有 403 错误
- jira - 作为我们项目的一部分,我们希望捕获 JIRA 中可用的信息,以便我们冲刺到数据库中”
- swift - 在单击集合视图之前,垂直 NSTableView 中的水平 NSCollectionView 不会选择
- php - 核心 PHP,从 PDF 文件中隐藏一些文本
- r - 在数据框中以交替方式将模式添加到每一行