Multiple loss functions for multiple inputs in Keras

Problem description

My network has 2 inputs: a 3D grayscale volume and a 2D colour image. It ends in a 2D output with the same shape as the input image.

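My problem: I want to combine 2 loss functions, one for each input branch of the network. One loss should compare the network's prediction with a 2D grayscale ground truth, the other with the 2D colour input. These ground truths are stored in separate files.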
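Written out, the objective is a sum of terms, one per ground truth. Keras has the same idea built in for models with several outputs (`compile(loss=[...], loss_weights=[...])` adds up one weighted loss per output), but here a single output is compared against two targets, so the sum has to live inside one loss function. A minimal plain-Python sketch, assuming images flattened to equal-length lists of floats and hypothetical weights `w_gray` and `w_colour`, with the SSIM term left out for brevity:

```python
def l1(y_true, y_pred):
    """Mean absolute error between two flattened images of equal length."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def combined_loss(pred, gt_gray, gt_colour, w_gray=1.0, w_colour=1.0):
    """One L1 term per ground truth, combined as a weighted sum (weights are hypothetical)."""
    return w_gray * l1(gt_gray, pred) + w_colour * l1(gt_colour, pred)


pred = [0.5, 0.5, 0.5, 0.5]
gt_gray = [0.0, 1.0, 0.0, 1.0]
gt_colour = [0.5, 0.5, 1.0, 1.0]
print(combined_loss(pred, gt_gray, gt_colour))                # 0.75
print(combined_loss(pred, gt_gray, gt_colour, w_colour=0.0))  # 0.5
```

With both weights at 1.0 this is simply L1(grayscale) + L1(colour); changing the weights trades off how closely the prediction follows each ground truth.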

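I don't understand how y_true works, or how to tell it to look at my ground-truth files. I think my problem is mostly one of syntax, and after reading other posts on SO I'm even more confused.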
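`y_true` is not tied to any file. Keras uses whatever is passed as `y` to `model.fit(x, y)`: it slices `x` and `y` into batches in the same order and, for each batch, calls `loss(y_true, y_pred)` with that batch's targets and the model's predictions for the matching inputs. The ground-truth files therefore have to be loaded into arrays (or into a `tf.data.Dataset` or generator yielding `(inputs, targets)` pairs) beforehand; for a two-input model, `x` is a list in the order of the model's inputs, such as `[colour_images, volumes]`. The plain-Python sketch below mimics only this pairing; `fake_fit` and `predict` are hypothetical stand-ins, not Keras API, and shuffling is left out:

```python
def fake_fit(x, y, predict, loss, batch_size):
    """Mimic how fit() pairs each input batch with its target batch (no shuffling)."""
    batch_losses = []
    for start in range(0, len(x), batch_size):
        x_batch = x[start:start + batch_size]
        y_true = y[start:start + batch_size]  # targets for exactly these inputs
        y_pred = [predict(sample) for sample in x_batch]
        batch_losses.append(loss(y_true, y_pred))
    return batch_losses


def mae(y_true, y_pred):
    """Mean absolute error over one batch."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


x = [1.0, 2.0, 3.0, 4.0]  # inputs
y = [2.0, 4.0, 6.0, 9.0]  # ground truth, already loaded into memory
print(fake_fit(x, y, predict=lambda s: 2 * s, loss=mae, batch_size=2))  # [0.0, 0.5]
```

The loss never sees the whole dataset, only the targets of the current batch, which is why ground truths have to be supplied through `y` rather than captured from outside.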

This is the code I came up with. It obviously doesn't work, but it should show what I'm aiming for.

def custom_loss(groundtruth_grayscale, groundtruth_colour):
    def loss(y_true, y_pred):
        loss_grayscale = ssim(y_pred, groundtruth_grayscale)
        loss_colour = ssim(y_pred, groundtruth_colour)
        ssim_loss = loss_grayscale + loss_colour

        l1_loss_grayscale = l1(y_pred, groundtruth_grayscale)
        l1_loss_colour = l1(y_pred, groundtruth_colour)
        l1_loss = l1_loss_grayscale + l1_loss_colour

        return ssim_loss + l1_loss
    return loss

# images_groundtruth_grayscale is a variable containing all groundtruth_grayscale images
model_combined.optimizer = tf.keras.optimizers.Adam(learning_rate=0.002).minimize(
    custom_loss(images_groundtruth_grayscale, images_groundtruth_colour),
    var_list=model_combined.trainable_variables)
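One reason the attempt above cannot work is that the closure in `custom_loss` captures the entire ground-truth datasets, while the loss is evaluated batch by batch, so predictions are never matched with their own targets. A common workaround (an assumption about the data layout, not the only option) is to pack both ground truths into the single `y` passed to `fit()`, for example by concatenating them along the last (channel) axis, and to split `y_true` back into its two parts inside the loss (in TensorFlow, with slices such as `y_true[..., :1]` and `y_true[..., 1:]`). The optimizer and loss would then be given to `model_combined.compile()` instead of calling `minimize()` once by hand. A plain-Python sketch of the unpacking, where each target is a hypothetical `(grayscale, colour)` pair of flattened images:

```python
def l1(a, b):
    """Mean absolute error between two flattened images of equal length."""
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def packed_loss(y_true, y_pred):
    """Each y_true entry is a (grayscale_gt, colour_gt) pair for one sample."""
    total = 0.0
    for (gt_gray, gt_colour), pred in zip(y_true, y_pred):
        total += l1(gt_gray, pred) + l1(gt_colour, pred)
    return total / len(y_pred)  # mean over the batch


# A batch of two samples; targets line up with predictions by position.
y_true = [([0.0, 1.0], [1.0, 1.0]), ([0.0, 0.0], [0.0, 1.0])]
y_pred = [[0.5, 0.5], [0.0, 0.5]]
print(packed_loss(y_true, y_pred))  # 0.75
```

Note also that SSIM is a similarity score that equals 1.0 for identical images, so as a loss it is normally used as `1 - SSIM` (for example with `tf.image.ssim(img1, img2, max_val)`); adding raw SSIM values, as `custom_loss` does, would reward predictions that look less like the ground truth.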

In case it helps, here is the model summary:

model_combined.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, None, None, 1 9728        input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, None, None, 1 512         conv2d[0][0]                     
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, None, None, N 16128       input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 6 204864      batch_normalization[0][0]        
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, None, None, N 512         conv3d[0][0]                     
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 6 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, None, None, N 1024064     batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, None, None, 3 18464       batch_normalization_1[0][0]      
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, None, None, N 256         conv3d_1[0][0]                   
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, None, None, 3 128         conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv3d_2 (Conv3D)               (None, None, None, N 55328       batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, None, None, 1 4624        batch_normalization_2[0][0]      
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, None, None, N 128         conv3d_2[0][0]                   
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, None, None, 1 64          conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv3d_3 (Conv3D)               (None, None, None, N 13840       batch_normalization_6[0][0]      
__________________________________________________________________________________________________
tf_op_layer_ExpandDims (TensorF [(None, None, None,  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, None, None, N 64          conv3d_3[0][0]                   
__________________________________________________________________________________________________
tf_op_layer_add (TensorFlowOpLa [(None, None, None,  0           tf_op_layer_ExpandDims[0][0]     
                                                                 batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv3d_4 (Conv3D)               (None, None, None, N 272         tf_op_layer_add[0][0]            
==================================================================================================
Total params: 1,349,232
Trainable params: 1,348,272
Non-trainable params: 960
__________________________________________________________________________________________________
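The Output Shape column above is truncated, but the channel counts can be recovered from the parameter counts: a convolution with bias has (kernel volume × input channels + 1) × filters parameters, and a BatchNormalization layer has 4 × channels (gamma and beta are trainable; the moving mean and variance are not, which accounts for the 960 non-trainable parameters). The sketch below checks one consistent reading of the table; the input channel counts (3 for the colour image, 1 for the volume) and the kernel sizes are inferences, not stated in the question:

```python
from math import prod


def conv_params(kernel, c_in, filters):
    """Conv weights plus one bias per filter (Keras default use_bias=True)."""
    return (prod(kernel) * c_in + 1) * filters


def bn_params(channels):
    """gamma, beta (trainable) + moving mean, moving variance (non-trainable)."""
    return 4 * channels


# Kernel sizes and input channel counts are inferred from the table, not known.
conv_layers = {
    "conv2d": conv_params((5, 5), 3, 128),
    "conv2d_1": conv_params((5, 5), 128, 64),
    "conv2d_2": conv_params((3, 3), 64, 32),
    "conv2d_3": conv_params((3, 3), 32, 16),
    "conv3d": conv_params((5, 5, 5), 1, 128),
    "conv3d_1": conv_params((5, 5, 5), 128, 64),
    "conv3d_2": conv_params((3, 3, 3), 64, 32),
    "conv3d_3": conv_params((3, 3, 3), 32, 16),
    "conv3d_4": conv_params((1, 1, 1), 16, 16),
}
# Each branch has a BatchNormalization after convs with 128, 64, 32 and 16 channels.
bn_layers = [bn_params(c) for c in (128, 64, 32, 16) * 2]

total = sum(conv_layers.values()) + sum(bn_layers)
non_trainable = sum(bn_layers) // 2  # moving mean + moving variance

print(conv_layers["conv2d"], conv_layers["conv3d_1"])  # 9728 1024064
print(total, total - non_trainable, non_trainable)     # 1349232 1348272 960
```

Under this reading, both branches narrow from 128 to 64, 32 and finally 16 channels before the two feature maps are added.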

Tags: python, tensorflow, keras
