tensorflow - 如何在 keras 中使用额外的变量自定义损失函数
问题描述
我想训练一个具有自定义损失函数的模型。损失包括两部分。Part1 和 part2 可以用 y_true(标签)和 y_predicted(实际输出)计算。但是,那loss = part1 +lambda part2
lambda 是一个变量,应该能够与网络模型的参数一起调整。在张量流中,似乎可以将 lambda 定义为 tf.Variable 以进行更新。但是,我怎样才能在 Keras 中做到这一点?
解决方案
好吧,我想出了一个解决方案。这很丑陋,但它是一个解决方案:
class UtilityLayer(Layer):
def build(self, input_shape):
self.kernel = self.add_weight(
name='kernel',
shape=(1,),
initializer='ones',
trainable=True,
constraint='nonneg'
)
super().build(input_shape)
def call(self, inputs, **kwargs):
return self.kernel
switch = -1
last_loss = 0
def custom_loss_builder(utility_layer):
def custom_loss(y_true, y_pred):
global switch, last_loss
switch *= -1
if switch == 1:
last_loss = utility_layer.trainable_weights[0] * MSE(y_true, y_pred)
return last_loss # your network loss
else:
return last_loss # your lambda loss
return custom_loss
dummy_y = np.empty(len(x))
inputs = Input(shape=(1,))
x = Dense(2, activation='relu')(inputs)
outputs = Dense(1)(x)
utility_outputs = UtilityLayer()(inputs)
model = Model(inputs, [outputs, utility_outputs])
model.compile(optimizer='adam', loss=custom_loss_builder(model.layers[-1]))
model.fit(x, [y, dummy_y], epochs=100)
以及您的 lambda 的演变:
推荐阅读
- c# - 如何比较预制件和游戏对象?
- regex - 如果之前的句子中有特定的单词,则匹配一个单词
- sapui5 - 使用 typescript 进行开发和使用 Component-preload.js 进行生产调试
- flutter - 强制 TabController 创建所有选项卡 Flutter
- ios - Parse - 查询数组指针中的值
- azure - 添加新字段后如何重新索引
- c++ - 如何修复 SDL2main.lib 中的所有这些未定义项?
- swift - 无法在 Swift 中将文件传输/删除到 ~/Library/Application Scripts/
- c# - 如何在 Moq 设置方法中返回 DbRawSqlQuery
- bash - Bash - 传递给 curl 和 grep 命令的变量不起作用