python - 如何在 Keras 中为具有多个输入和输出的模型定义自定义损失函数?
问题描述
我正在研究一种架构的 Keras 实现,该架构采用 2 个输入(input_im_low、input_im_high)并将它们分别传递给一个架构并获得 2 个输出。网络使用基于这两个输出和输入定义的自定义损失函数,并且没有任何基本事实,因为训练的目标是减少自定义损失。我尝试根据 Keras 文档(https://keras.io/api/losses/)创建损失函数,但它无法正常工作。我不确定是否需要使用 add_loss() API,如果需要,应该如何针对我的情况进行调整?
谁能告诉我如何调整损失函数?
我通过运行代码得到以下错误:
per_sample_losses = loss_fn.call(targets[i], outs[i])
IndexError: list index out of range
这是代码:
# DecomNet - low
input_im_low = Input(shape=[None, None, 3], dtype=tf.float32, name='input_im_low')
input_max_low = tf.reduce_max(input_im_low, axis=3, keepdims=True)
input_im_low_full = concatenate([input_max_low, input_im_low], axis=3)
conv = Conv2D(channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction_low")(input_im_low_full)
for idx in range(5):
conv = Conv2D(channel, kernel_size, padding='same', activation='relu', name='activated_layer_low_%d' % idx)(conv)
conv = Conv2D(4, kernel_size, padding='same', activation=None, name='recon_layer_low')(conv)
out_low = tf.sigmoid(conv)
# DecomNet - High
input_im_high = Input(shape=[None, None, 3], dtype=tf.float32, name='input_im_high')
input_max_high = tf.reduce_max(input_im_high, axis=3, keepdims=True)
input_im_high_full = concatenate([input_max_high, input_im_high], axis=3)
conv = Conv2D(channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction_high")(input_im_high_full)
for idx in range(5):
conv = Conv2D(channel, kernel_size, padding='same', activation='relu', name='activated_layer_high_%d' % idx)(conv)
conv = Conv2D(4, kernel_size, padding='same', activation=None, name='recon_layer_high')(conv)
out_high = tf.sigmoid(conv)
def Decom_loss(out_high, out_low):
R_low = out_low[:, :, :, 0:3]
I_low = out_low[:, :, :, 3:4]
R_high = out_high[:, :, :, 0:3]
I_high = out_high[:, :, :, 3:4]
I_low_3 = concatenate([I_low, I_low, I_low], axis=3)
I_high_3 = concatenate([I_high, I_high, I_high], axis=3)
# output_R_low = R_low
# output_I_low = I_low_3
recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - input_im_low))
recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - input_im_high))
recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - input_im_low))
recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - input_im_high))
equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
Ismooth_loss_low = smooth(I_low, R_low)
Ismooth_loss_high = smooth(I_high, R_high)
loss_Decom = recon_loss_low + recon_loss_high + 0.001 * recon_loss_mutal_low + 0.001 * recon_loss_mutal_high + 0.1 * Ismooth_loss_low + 0.1 * Ismooth_loss_high + 0.01 * equal_R_loss
return loss_Decom
model = Model(inputs=[input_im_low, input_im_high], outputs=[out_low, out_high])
model.compile(optimizer='adam', loss=Decom_loss)
model.fit({'input_im_low': train_low_data, 'input_im_high': train_high_data}, epochs=10, batch_size=4)
PS - 这个网络的 Tensorflow 1.x 实现可以在https://github.com/weichen582/RetinexNet找到。
解决方案
推荐阅读
- c# - 条目上的多种行为(Xamarin Forms)
- c++ - C++ 用不重复的随机数填充数组
- flutter - 如何修复这个 Flutter xcode 构建?
- javascript - 从 DB 解析 JSON 并在前端显示
- javascript - 有条件地更改 Javascript 对象中键的顺序
- javascript - Javascript 在没有警报弹出窗口的情况下获取月/年
- bash - 如何在 Bash 中计算整数列表的总和
- nativescript - 使用谷歌提供商登录 Nativescript 应用程序的问题
- git - 如何显示实际上导致两个分支之间差异的 git 提交列表?
- javascript - 在浏览器关闭时为 React 明确清除本地存储