python - 通过步长减少自定义损失函数的权重
问题描述
我想随着步长的增加改变对损失施加的权重。为此,我使用了tf.keras.losses.Loss的子类。但是,其函数中的参数,如 __init__() 或 call() 似乎在计算过程中无法执行。
如何在 tf.keras.losses.Loss 的子类中获取步数?
这是我的代码。
class CategoricalCrossentropy(keras.losses.Loss):
def __init__(self, weight, name="example"):
super().__init__(name=name)
self.weight = weight
def call(self, y_true, y_pred):
weight = self.weight*np.exp(-1.0*step) #I'd like to use step number here to reduce weight.
loss = -tf.reduce_sum(weight*y_true*tf.math.log(y_pred))/y_shape[1]/y_shape[2] #impose weight on CategoricalCrossentropy
return loss
解决方案
编辑(这是因为你没有告诉函数 step 的值是什么,这将是一个局部变量,函数将无法收集它,因为它有自己的局部变量。)
我假设您正在通过迭代设置步骤。只需将其作为输入添加到调用函数中。
class CategoricalCrossentropy(keras.losses.Loss):
def __init__(self, weight, name="example"):
super().__init__(name=name)
self.weight = weight
def call(self, step, y_true, y_pred):
weight = self.weight*np.exp(-1.0*step) #I'd like to use step number here to reduce weight.
loss = -tf.reduce_sum(weight*y_true*tf.math.log(y_pred))/y_shape[1]/y_shape[2] #impose weight on CategoricalCrossentropy
return loss
keras_weight = 1
keras_example = CategoricalCrossentropy(keras_weight)
for step in range(1, max_step+1): # assuming step cannot = 0
loss = keras_example.call(1, y_true, y_perd)
如果您希望该步骤成为对象记住的内容,您可以简单地添加一个属性。
class CategoricalCrossentropy1(keras.losses.Loss):
def __init__(self, weight, name="example"):
super().__init__(name=name)
self.weight = weight
self.step = 1 #again, assuming step cannot = 0
def call(self, y_true, y_pred):
weight = self.weight*np.exp(-1.0*self.step) #I'd like to use step number here to reduce weight.
loss = -tf.reduce_sum(weight*y_true*tf.math.log(y_pred))/y_shape[1]/y_shape[2] #impose weight on CategoricalCrossentropy
self.step += 1 # add to step
return loss
希望这会有所帮助
推荐阅读
- mysql - Laravel 迁移正在运行,但无法使用 Eloquent 访问
- android - Recyclerview 未更新 searchview 文本更改
- android - 使用 react-native-swiper 时,OnScroll 不会在 android 上发送任何事件
- sql-server-2008 - 使用 cte 中的 rank() 函数删除重复记录
- sql - 有没有办法获取每个月的月底日期
- sql-server - How to check if a query returns a null value and add it to a label text?
- visual-studio-extensions - 为什么 VS2019 将我的扩展标记为已弃用
- php - 当mysql查询死亡但查询在dbms中有效时如何修复我的代码?
- php - 在数据表插件上的 mysql WHERE 子句上使用 GET 方法不起作用
- python - 在 python 3.7 中安装 pyhook “找不到满足 pyHook 要求的版本”