首页 > 解决方案 > Variational Autoencoder KL散度损失爆炸,模型返回nan

问题描述

我正在为 MRI 大脑图像(2D 切片)训练 Conv-VAE。模型的输出是sigmoid,损失函数二元交叉熵:

x = input, x_hat = output

rec_loss = nn.functional.binary_cross_entropy(x_hat.view(-1, 128 ** 2), x.view(-1, 128 ** 2),reduction='sum')

但我的问题实际上是KL散度损失:

KL_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

在训练的某个时刻,KL 散度损失非常高(某处无穷大)

在此处输入图像描述

然后我遇到了你可以在下面看到的错误,这可能是导致输出为 nan。关于如何避免这种爆炸的任何建议?

标签: deep-learningpytorchgradientautoencoder

解决方案


推荐阅读