首页 > 解决方案 > pytorch代码中的KL-divergence与公式有什么关系?

问题描述

在 VAE 教程中,两个正态分布的 kl 散度定义为: 在此处输入图像描述

而在很多代码中,比如hereherehere,代码实现如下:

 KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

或者

def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

它们有什么关系?为什么代码中没有任何“tr”或“.transpose()”?

标签: pytorchautoencoderloss-function

解决方案


您发布的代码中的表达式假设 X 是一个不相关的多元高斯随机变量。协方差矩阵的行列式中缺少交叉项,这一点很明显。因此,均值向量和协方差矩阵采用以下形式

在此处输入图像描述

使用它,我们可以快速推导出原始表达式的组件的以下等效表示

在此处输入图像描述

将这些代回原始表达式给出

在此处输入图像描述


推荐阅读