python - Keras 中多个输入的多个损失函数
问题描述
我的网络有 2 个输入,一个 3D 灰度体积和一个 2D 彩色图像。最后,我有一个 2D 输出,与输入图像的形状相同。
我的问题:我想结合 2 个损失函数,一个用于网络的每个输入分支。一个损失函数应该将网络的预测与 2D 灰度地面实况进行比较,另一个与 2D 颜色输入进行比较。这些基本事实存储在单独的文件中。
我不明白 y_true 是如何工作的,以及如何告诉它查看我的 groundtruth 文件。我认为我的问题主要是语法之一,在查看 SO 上的其他帖子后,我更加困惑。
这是我想出的代码。它当然不起作用,但它应该让我知道我的目标是什么。
def custom_loss(groundtruth_grayscale, groundtruth_colour):
def loss(y_true, y_pred):
loss_grayscale = ssim(y_pred, groundtruth_grayscale)
loss_colour = ssim(y_pred, groundtruth_colour)
ssim_loss = loss_grayscale + loss_colour
l1_loss_grayscale = l1(y_pred, groundtruth_grayscale)
l1_loss_colour = l1(y_pred, groundtruth_colour)
l1_loss = l1_loss_grayscale + l1_loss_colour
return ssim_loss + l1_loss
return loss
# images_groundtruth_grayscale is a variable containing all groundtruth_grayscale images
model_combined.optimizer = tf.keras.optimizers.Adam(learning_rate = 0.002).minimize(custom_loss(images_groundtruth_grayscale, images_groundtruth_colour), var_list = model_combined.trainable_variables)
如果有帮助,模型的摘要:
model_combined.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, None, None, 1 9728 input_1[0][0]
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, None, None, 1 512 conv2d[0][0]
__________________________________________________________________________________________________
conv3d (Conv3D) (None, None, None, N 16128 input_2[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, None, None, 6 204864 batch_normalization[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, None, None, N 512 conv3d[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 6 256 conv2d_1[0][0]
__________________________________________________________________________________________________
conv3d_1 (Conv3D) (None, None, None, N 1024064 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, None, None, 3 18464 batch_normalization_1[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, None, None, N 256 conv3d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, None, None, 3 128 conv2d_2[0][0]
__________________________________________________________________________________________________
conv3d_2 (Conv3D) (None, None, None, N 55328 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, None, None, 1 4624 batch_normalization_2[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, None, None, N 128 conv3d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, None, None, 1 64 conv2d_3[0][0]
__________________________________________________________________________________________________
conv3d_3 (Conv3D) (None, None, None, N 13840 batch_normalization_6[0][0]
__________________________________________________________________________________________________
tf_op_layer_ExpandDims (TensorF [(None, None, None, 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, None, None, N 64 conv3d_3[0][0]
__________________________________________________________________________________________________
tf_op_layer_add (TensorFlowOpLa [(None, None, None, 0 tf_op_layer_ExpandDims[0][0]
batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv3d_4 (Conv3D) (None, None, None, N 272 tf_op_layer_add[0][0]
==================================================================================================
Total params: 1,349,232
Trainable params: 1,348,272
Non-trainable params: 960
__________________________________________________________________________________________________
解决方案
推荐阅读
- angular - 解析 Ionic (Angular) 期间的 Http 失败
- r - 如何使用 polychor() 在输入矩阵中获取所有成对比较?
- javascript - javascript 检查数组的值是否作为对象中的键存在
- reactjs - api调用后是否可以立即获得价值?
- sql - 在 SQL Server 中使用 SUM 函数时如何显示空值
- css - 编译 Sass 时出错:$spacers Bootstrap 处的 ModuleBuildError
- vue.js - Vuetify:如何在输入中出现 v-text-field 标签
- javascript - 发送发送时UseEffect不更新反应网站数据
- wordpress - 如何使特定页面显示小部件区域
- java - 如何比较 2 个对象的年龄值?