How to define a custom loss function in Keras for a model with multiple inputs and outputs?

Problem Description

I am working on a Keras implementation of an architecture that takes 2 inputs (input_im_low, input_im_high), passes each of them separately through a sub-network, and produces 2 outputs. The network is trained with a custom loss function defined on these two outputs and the two inputs, and there is no ground truth, because the goal of training is simply to minimize this custom loss. I tried to create the loss function following the Keras documentation (https://keras.io/api/losses/), but it does not work. I am not sure whether I need to use the add_loss() API instead, and if so, how should it be adapted to my case?

Can anyone tell me how to adapt the loss function?

Running the code, I get the following error:

    per_sample_losses = loss_fn.call(targets[i], outs[i])
IndexError: list index out of range

Here is the code:

# Imports assumed for this snippet (TF 2.x / tf.keras)
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, concatenate
from tensorflow.keras.models import Model

# channel and kernel_size are not shown in the original snippet; the values
# below are the defaults used by the original RetinexNet DecomNet
channel = 64
kernel_size = 3

# DecomNet - low: the per-pixel channel maximum is concatenated to the RGB
# input, followed by a shallow feature-extraction conv, 5 ReLU conv layers,
# and a 4-channel reconstruction layer (3 reflectance + 1 illumination)

input_im_low = Input(shape=[None, None, 3], dtype=tf.float32, name='input_im_low')
input_max_low = tf.reduce_max(input_im_low, axis=3, keepdims=True)
input_im_low_full = concatenate([input_max_low, input_im_low], axis=3)

conv = Conv2D(channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction_low")(input_im_low_full)
for idx in range(5):
    conv = Conv2D(channel, kernel_size, padding='same', activation='relu', name='activated_layer_low_%d' % idx)(conv)
conv = Conv2D(4, kernel_size, padding='same', activation=None, name='recon_layer_low')(conv)
out_low = tf.sigmoid(conv)

# DecomNet - High
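# Same decomposition sub-network (with its own weights) applied to the
# normal-light input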

input_im_high = Input(shape=[None, None, 3], dtype=tf.float32, name='input_im_high')
input_max_high = tf.reduce_max(input_im_high, axis=3, keepdims=True)
input_im_high_full = concatenate([input_max_high, input_im_high], axis=3)

conv = Conv2D(channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction_high")(input_im_high_full)
for idx in range(5):
    conv = Conv2D(channel, kernel_size, padding='same', activation='relu', name='activated_layer_high_%d' % idx)(conv)
conv = Conv2D(4, kernel_size, padding='same', activation=None, name='recon_layer_high')(conv)
out_high = tf.sigmoid(conv)

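# RetinexNet decomposition loss: reconstruction, mutual-reconstruction,
# reflectance-consistency and illumination-smoothness terms; note that it
# closes over the symbolic input tensors input_im_low / input_im_high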
def Decom_loss(out_high, out_low):

    R_low = out_low[:, :, :, 0:3]
    I_low = out_low[:, :, :, 3:4]

    R_high = out_high[:, :, :, 0:3]
    I_high = out_high[:, :, :, 3:4]

    I_low_3 = concatenate([I_low, I_low, I_low], axis=3)
    I_high_3 = concatenate([I_high, I_high, I_high], axis=3)

    # output_R_low = R_low
    # output_I_low = I_low_3

    recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - input_im_low))
    recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - input_im_high))
    recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - input_im_low))
    recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - input_im_high))
    equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))

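    # smooth() is the illumination-smoothness term from the original
    # RetinexNet implementation, assumed to be defined elsewhere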
    Ismooth_loss_low = smooth(I_low, R_low)
    Ismooth_loss_high = smooth(I_high, R_high)

    loss_Decom = (recon_loss_low + recon_loss_high
                  + 0.001 * recon_loss_mutal_low + 0.001 * recon_loss_mutal_high
                  + 0.1 * Ismooth_loss_low + 0.1 * Ismooth_loss_high
                  + 0.01 * equal_R_loss)

    return loss_Decom

model = Model(inputs=[input_im_low, input_im_high], outputs=[out_low, out_high])
model.compile(optimizer='adam', loss=Decom_loss)
model.fit({'input_im_low': train_low_data, 'input_im_high': train_high_data}, epochs=10, batch_size=4)

PS - A TensorFlow 1.x implementation of this network can be found at https://github.com/weichen582/RetinexNet.

Tags: python, tensorflow, keras, deep-learning

Solution
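
The IndexError comes from compile(loss=Decom_loss): passing a loss this way makes Keras expect one target per model output, but model.fit() is called without any target data, so the internal targets list is empty when Keras looks up targets[i] for each output. Since the loss here depends only on the model's inputs and outputs and there is no ground truth, the add_loss() API mentioned in the question is one commonly used way around this: build the loss as a tensor from the symbolic input/output tensors and register it on the model, so that compile() takes no loss argument and fit() needs no targets. The following is only a minimal sketch of that idea, not a tested drop-in solution; it reuses the tensors defined above, assumes TF 2.x tf.keras (where TF ops on symbolic tensors are allowed, as in the question's own model code), and assumes smooth() is defined as in the original RetinexNet code.

def decom_loss_tensor(in_low, in_high, out_low, out_high):
    # Same terms as Decom_loss above, written as a function of the symbolic
    # input and output tensors so the result can be passed to model.add_loss()
    R_low, I_low = out_low[:, :, :, 0:3], out_low[:, :, :, 3:4]
    R_high, I_high = out_high[:, :, :, 0:3], out_high[:, :, :, 3:4]

    I_low_3 = tf.concat([I_low] * 3, axis=3)
    I_high_3 = tf.concat([I_high] * 3, axis=3)

    recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - in_low))
    recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - in_high))
    recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - in_low))
    recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - in_high))
    equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))

    # smooth() as in the original RetinexNet implementation
    Ismooth_loss_low = smooth(I_low, R_low)
    Ismooth_loss_high = smooth(I_high, R_high)

    return (recon_loss_low + recon_loss_high
            + 0.001 * recon_loss_mutal_low + 0.001 * recon_loss_mutal_high
            + 0.1 * Ismooth_loss_low + 0.1 * Ismooth_loss_high
            + 0.01 * equal_R_loss)

model = Model(inputs=[input_im_low, input_im_high], outputs=[out_low, out_high])
# Register the loss tensor built from the symbolic inputs and outputs
model.add_loss(decom_loss_tensor(input_im_low, input_im_high, out_low, out_high))
model.compile(optimizer='adam')  # no loss= argument and no targets needed
model.fit({'input_im_low': train_low_data, 'input_im_high': train_high_data},
          epochs=10, batch_size=4)

If a newer Keras version rejects TF ops on symbolic KerasTensors, the same loss tensor can instead be computed inside a small custom layer that calls self.add_loss() on it.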

