VAE reconstruction loss is not computed correctly

Problem description

I want to implement a variational autoencoder (VAE) and train it on my texture dataset, which consists of color images of size 512x512.

For the VAE implementation I used the Keras VAE example at https://keras.io/examples/generation/vae/ as a starting point and changed the layer structure. After training, I looked at some of the reconstructed images and they are really blurry. I then looked at how the reconstruction loss is computed and found it in this code:

def train_step(self, data):
    if isinstance(data, tuple):
        data = data[0]
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = encoder(data)
        reconstruction = decoder(z)
        # Per-pixel binary cross-entropy, averaged over batch and pixels
        reconstruction_loss = tf.reduce_mean(
            keras.losses.binary_crossentropy(data, reconstruction)
        )
        reconstruction_loss *= 28 * 28  # scale by the image size (MNIST: 28x28)
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_mean(kl_loss)
        kl_loss *= -0.5
        total_loss = reconstruction_loss + kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    return {
        "loss": total_loss,
        "reconstruction_loss": reconstruction_loss,
        "kl_loss": kl_loss,
    }

Since the example VAE is written for the MNIST dataset (image size 28x28), my intuition was that I have to change the reconstruction loss scaling to

reconstruction_loss *= 512 * 512  # since my images are of size 512x512
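The intent of that scaling can be checked with plain arithmetic: multiplying the per-pixel mean BCE by H*W is the same as summing the per-pixel losses within each image and then averaging over the batch. Here is a minimal NumPy sketch, using toy 4x4 arrays as stand-ins for 512x512 images (the values are random placeholders, not from the actual model):

```python
import numpy as np

# Hypothetical per-pixel binary cross-entropy values for a batch of 2 "images"
# of size 4x4, standing in for 512x512 -- only the arithmetic matters here.
rng = np.random.default_rng(0)
bce = rng.uniform(0.0, 1.0, size=(2, 4, 4))  # shape (batch, H, W)

H, W = 4, 4
# Variant A: mean over batch and pixels, then scale by H*W (the example's approach)
loss_a = bce.mean() * H * W
# Variant B: sum the loss per image, then average over the batch
loss_b = bce.sum(axis=(1, 2)).mean()

# Both variants are mathematically identical
assert np.isclose(loss_a, loss_b)
```

So the `*= 512 * 512` line itself is consistent with a per-image summed reconstruction loss; the scale factor mainly changes how strongly the reconstruction term dominates the KL term.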

When I now train the VAE, total_loss regularly spikes to values like 904815524.6657 in the first training epoch, even though the reconstruction loss in that same epoch is around 15000 and the average total_loss is also around 15000. After training for a while, total_loss and kl_loss sometimes also become NaN. On top of that, the reconstructed images with reconstruction_loss *= 28 * 28 actually seem better than the ones with 512 * 512...
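One plausible mechanism for those spikes and NaNs (an assumption, not something confirmed by the question) is the `tf.exp(z_log_var)` term in the KL loss: if the learned log-variance drifts to large values, the exponential overflows to infinity and poisons the total loss. A small NumPy sketch of the failure mode and a common mitigation, clipping the log-variance (the clip range of [-10, 10] is an illustrative choice):

```python
import numpy as np

# Illustrative latent statistics; 800.0 is a pathologically large log-variance.
z_mean = np.array([0.5, -1.0])
z_log_var = np.array([800.0, -3.0])

# The KL term from the question: -0.5 * mean(1 + log_var - mean^2 - exp(log_var))
with np.errstate(over="ignore"):
    kl_raw = -0.5 * np.mean(1 + z_log_var - z_mean**2 - np.exp(z_log_var))
# exp(800) overflows float64 to inf, so the KL term blows up
assert np.isinf(kl_raw)

# Mitigation sketch: clip the log-variance before exponentiating
z_log_var_clipped = np.clip(z_log_var, -10.0, 10.0)
kl_safe = -0.5 * np.mean(
    1 + z_log_var_clipped - z_mean**2 - np.exp(z_log_var_clipped)
)
assert np.isfinite(kl_safe)
```

In the TensorFlow model the analogous guard would clip `z_log_var` (e.g. with `tf.clip_by_value`) before the `tf.exp` call; whether that is the actual cause here would need to be verified by logging `z_log_var` during training.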

Did I implement the loss correctly, and how can I deal with this absurdly high total loss?

Edit:

Model summaries of the encoder and decoder:

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
activation (Activation)         (None, 256, 256, 64) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 256, 256, 64) 256         activation[0][0]                 
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 64) 36928       batch_normalization[0][0]        
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 128, 128, 64) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 64) 256         activation_1[0][0]               
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 128)  73856       batch_normalization_1[0][0]      
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 128)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 64, 64, 128)  512         activation_2[0][0]               
__________________________________________________________________________________________________
flatten (Flatten)               (None, 524288)       0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
dense (Dense)                   (None, 512)          268435968   flatten[0][0]                    
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 512)          262656      dense[0][0]                      
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 512)          262656      dense[0][0]                      
__________________________________________________________________________________________________
sampling (Sampling)             (None, 512)          0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 269,074,880
Trainable params: 269,074,368
Non-trainable params: 512
__________________________________________________________________________________________________

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 512)]             0         
_________________________________________________________________
dense_2 (Dense)              (None, 524288)            268959744 
_________________________________________________________________
reshape_1 (Reshape)          (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 128, 128, 128)     147584    
_________________________________________________________________
batch_normalization_6 (Batch (None, 128, 128, 128)     512       
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 256, 256, 64)      73792     
_________________________________________________________________
batch_normalization_7 (Batch (None, 256, 256, 64)      256       
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 512, 512, 64)      36928     
_________________________________________________________________
batch_normalization_8 (Batch (None, 512, 512, 64)      256       
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 512, 512, 3)       1731      
=================================================================
Total params: 269,220,803
Trainable params: 269,220,291
Non-trainable params: 512
_________________________________________________________________

Tags: python, tensorflow, keras, neural-network
