首页 > 解决方案 > Pytorch 自定义/修改的 GAN 损失等价物

问题描述

我正在尝试实现一个 Pytorch 版本的Creative Adversarial Networks,这是一个具有修改/自定义损失函数的 GAN。

这是损失函数的公式。我将 Pytorchnn.CrossEntropyLoss用于鉴别器的修改损失函数,它似乎正在工作,因为它的损失随着时代的推移而减少,但我认为nn.CrossEntropyLoss它不适合生成器,正如nn.CrossEntropyLoss预期的那样Long,而不是Float张量,以及论文的损失函数,特别是生成器的损失,在我看来它需要浮点数。

这是我目前(最初)对生成器自定义损失的思考:

y_dim是类的数量

disc_class_layer= 在给定输入图像的情况下输出样式/类的 FC 层

for循环尝试等效于:

(sigma k=1 直到 k) ((1/K)log(Dc(ck|G(z)) + (1 - (1/K)log(1 - Dc(ck|G(z)))。

class CanGLoss(nn.Module):
def __init__(self,y_dim,labels,disc_class_layer):
    super(CanGLoss,self).__init__()

def forward(self,inp):
    style_loss = 0
    for i in range(1,y_dim+1):
        style_loss += (1/i)*torch.log(disc_class_layer(inp)) + (1 - (1/i))*torch.log(1-disc_class_layer(inp))
    return style_loss*-1

这是在正确的轨道上吗?我是自定义损失函数和 Pytorch 的新手,不确定这是要走的路。

任何帮助都会很棒!

标签: pythonmachine-learningdeep-learningpytorch

解决方案


推荐阅读