首页 > 解决方案 > 通过步长减少自定义损失函数的权重

问题描述

我想随着步长的增加改变对损失施加的权重。为此,我使用了tf.keras.losses.Loss的子类。但是,其函数中的参数,如 __init__() 或 call() 似乎在计算过程中无法执行。

如何在 tf.keras.losses.Loss 的子类中获取步数?

这是我的代码。

class CategoricalCrossentropy(keras.losses.Loss):
    def __init__(self, weight, name="example"):
        super().__init__(name=name)
        self.weight = weight

    def call(self, y_true, y_pred):

        weight = self.weight*np.exp(-1.0*step) #I'd like to use step number here to reduce weight.
        loss = -tf.reduce_sum(weight*y_true*tf.math.log(y_pred))/y_shape[1]/y_shape[2] #impose weight on CategoricalCrossentropy
        
        return loss

标签: pythontensorflowmachine-learningkerastensorflow2.0

解决方案


编辑(这是因为你没有告诉函数 step 的值是什么,这将是一个局部变量,函数将无法收集它,因为它有自己的局部变量。)

我假设您正在通过迭代设置步骤。只需将其作为输入添加到调用函数中。

class CategoricalCrossentropy(keras.losses.Loss):
    def __init__(self, weight, name="example"):
        super().__init__(name=name)
        self.weight = weight

    def call(self, step, y_true, y_pred):

        weight = self.weight*np.exp(-1.0*step) #I'd like to use step number here to reduce weight.
        loss = -tf.reduce_sum(weight*y_true*tf.math.log(y_pred))/y_shape[1]/y_shape[2] #impose weight on CategoricalCrossentropy
        
        return loss

keras_weight = 1
keras_example = CategoricalCrossentropy(keras_weight)

for step in range(1, max_step+1):     # assuming step cannot = 0
    loss = keras_example.call(1, y_true, y_perd)

如果您希望该步骤成为对象记住的内容,您可以简单地添加一个属性。

class CategoricalCrossentropy1(keras.losses.Loss):
    def __init__(self, weight, name="example"):
        super().__init__(name=name)
        self.weight = weight
        self.step = 1           #again, assuming step cannot = 0

    def call(self, y_true, y_pred):
        weight = self.weight*np.exp(-1.0*self.step) #I'd like to use step number here to reduce weight.
        loss = -tf.reduce_sum(weight*y_true*tf.math.log(y_pred))/y_shape[1]/y_shape[2] #impose weight on CategoricalCrossentropy
        
        self.step += 1  # add to step

        return loss

希望这会有所帮助


推荐阅读