Why is the generator loss 0.00 but the GAN still produces white images?

Problem description

I built a GAN (Generative Adversarial Network) to generate CIFAR-100 images. The model trains without errors, but it only produces white images.

My code is as follows (Colab notebook):

import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

tfs = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
train_data = datasets.CIFAR100('.', train = True, transform = tfs, download = True)
test_data = datasets.CIFAR100('.', train = False, transform = tfs, download = False)

# keep only the first 10,000 training images
idx = list(range(10000))
train_data = torch.utils.data.Subset(train_data, idx)

batch_size = 100
train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers = 4)
test_loader = DataLoader(test_data, batch_size = batch_size, shuffle = True, num_workers = 4)

class conv_layer(nn.Module):
  def __init__(self, in_channels, out_channels, ks, strides = 1, padding = 1,
               normalize = True, dropout = 0.0):
    super(conv_layer, self).__init__()
    layer = nn.Conv2d(in_channels, out_channels, ks, strides, padding)
    
    # Xavier (Glorot) initialization
    nn.init.xavier_normal_(layer.weight)
    
    layers = [layer]
    if normalize:
      layers.append(nn.InstanceNorm2d(out_channels))
    layers.append(nn.ReLU())
    if dropout:
      layers.append(nn.Dropout(dropout))
    
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    return self.model(x)

class Discriminator(nn.Module):
  def __init__(self, img_channels):
    super(Discriminator, self).__init__()

    self.l1 = conv_layer(img_channels, 64, 3)
    self.l2 = conv_layer(64, 128, 3, 2, 0)
    self.l3 = conv_layer(128, 256, 3, 2, 0)
    self.l4 = conv_layer(256, 512, 3, 2, 0)
    self.l5 = conv_layer(512, 1, 3, 2, 0, dropout = 0.5)
    self.l6 = nn.Sigmoid()

  def forward(self, x):
    out = self.l1(x)
    out = self.l2(out)
    out = self.l3(out)
    out = self.l4(out)
    out = self.l5(out)
    out = self.l6(out)

    return out

class deconv_layer(nn.Module):
  def __init__(self, in_channels, out_channels, ks, strides = 1, padding = 1,
               normalize = True):
    super(deconv_layer, self).__init__()
    layer = nn.ConvTranspose2d(in_channels, out_channels, ks, strides, padding)

    # Xavier (Glorot) initialization
    nn.init.xavier_normal_(layer.weight)

    layers = [layer]
    if normalize:
      layers.append(nn.InstanceNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2))
    
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    return self.model(x)

class Generator(nn.Module):
  def __init__(self, latent_dim, out_channel):
    super(Generator, self).__init__()

    self.l1 = nn.Linear(latent_dim, 256*4*4)
    self.l2 = deconv_layer(256, 128, 4, 2)
    self.l3 = deconv_layer(128, 128, 4, 2)
    self.l4 = deconv_layer(128, 128, 4, 2)
    self.l5 = nn.Conv2d(128, out_channel, 3, 1, 1)
    self.l6 = nn.Sigmoid()

  def forward(self, x):
    out = self.l1(x)
    out = out.reshape(-1, 256, 4, 4)
    out = self.l2(out)
    out = self.l3(out)
    out = self.l4(out)
    out = self.l5(out)
    out = self.l6(out)

    return out


image_shape = (3, 32, 32)
num_classes = 100
lr = 0.01
num_epochs = 500

latent_dim = 100
out_channel = image_shape[0]

gen = Generator(latent_dim, out_channel).to(device)
disc = Discriminator(out_channel).to(device)
optim_gen = torch.optim.Adam(gen.parameters(), lr = lr, weight_decay = 5e-4)
optim_disc = torch.optim.Adam(disc.parameters(), lr = lr, weight_decay = 3e-4)
criterion = nn.BCELoss()
criterion_gan = nn.CrossEntropyLoss()  # defined but never used below

import os
import shutil
if os.path.exists('log/fake'):
  shutil.rmtree('log/fake')
if os.path.exists('log/real'):
  shutil.rmtree('log/real')

writer_fake = SummaryWriter('log/fake')
writer_real = SummaryWriter('log/real')
step = 0


%reload_ext tensorboard
%tensorboard --logdir log

z_dim = 100  # must match the latent_dim the generator was built with
for epoch in range(num_epochs):
    for batch_idx, (real, label) in enumerate(train_loader):
        real = real.to(device)
        batch_size = real.shape[0]

        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc.zero_grad()
        
        # output of discriminator(real)
        disc_real = disc(real).view(-1)
        loss_real = criterion(disc_real, torch.ones_like(disc_real))

        # output of discriminator(fake)
        disc_fake = disc(fake.detach()).view(-1)
        loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        loss_disc = (loss_real + loss_fake) / 2
        
        loss_disc.backward()
        optim_disc.step()
        
        # Train generator
        gen.zero_grad()
        gen_image = gen(noise)
        loss_gen = criterion(gen_image, torch.ones_like(gen_image))
        
        loss_gen.backward()        
        optim_gen.step()
        


        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(train_loader)} \
                      Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
              fixed_noise = torch.randn((batch_size, z_dim)).to(device)
              fake = gen(fixed_noise)
              data = real
              img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
              img_grid_real = torchvision.utils.make_grid(data, normalize=True)

              writer_fake.add_image(
                  "Fake Images", img_grid_fake, global_step=step
              )
              writer_real.add_image(
                  "Real Images", img_grid_real, global_step=step
              )
              step += 1

I have run it for a number of epochs; here is the log:

Epoch [0/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.7558
Epoch [1/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [2/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0010
Epoch [3/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0002
Epoch [4/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0004
Epoch [5/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [6/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0007
Epoch [7/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [8/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [9/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [10/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [11/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [12/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [13/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [14/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [15/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [16/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [17/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [18/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [19/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [20/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [21/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [22/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [23/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [24/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [25/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [26/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [27/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [28/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0002
Epoch [29/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [30/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [31/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [32/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [33/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [34/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0004
Epoch [35/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [36/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0004
Epoch [37/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [38/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [39/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [40/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [41/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [42/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [43/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0002
Epoch [44/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [45/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [46/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [47/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [48/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [49/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0002
Epoch [50/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [51/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0008
Epoch [52/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [53/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [54/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [55/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [56/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [57/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [58/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [59/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [60/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [61/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [62/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [63/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [64/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [65/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [66/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001
Epoch [67/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [68/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [69/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [70/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [71/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [72/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0002
Epoch [73/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0000
Epoch [74/500] Batch 0/100                       Loss D: 0.6931, loss G: 0.0001

As you can see, both the discriminator loss and the generator loss are stagnating, and the generated images are blank white images. I cannot understand the reason for this result. Please explain. Thank you.

Tags: deep-learning, pytorch, loss-function, generative-adversarial-network

Solution
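The root cause is in the generator update. The generator loss is computed directly on the generated pixels:

loss_gen = criterion(gen_image, torch.ones_like(gen_image))

The discriminator is never consulted here. BCELoss between the image and a tensor of ones trains every output pixel towards 1.0, and because the generator ends in a Sigmoid, an all-ones output is exactly a white image. Once the Sigmoid saturates there, the loss reads ~0.0000, which matches your log. The adversarial loss must instead be the discriminator's score of the fake batch, pushed towards the "real" label. Below is a minimal sketch of the corrected updates, reusing the names from your code; the lr = 2e-4 and betas = (0.5, 0.999) settings are the usual DCGAN defaults, offered as a starting point rather than tuned values:

optim_gen = torch.optim.Adam(gen.parameters(), lr = 2e-4, betas = (0.5, 0.999))
optim_disc = torch.optim.Adam(disc.parameters(), lr = 2e-4, betas = (0.5, 0.999))

for epoch in range(num_epochs):
    for real, _ in train_loader:
        real = real.to(device)
        noise = torch.randn(real.shape[0], z_dim, device = device)
        fake = gen(noise)

        # Discriminator step: push D(real) towards 1, D(fake.detach()) towards 0
        disc.zero_grad()
        disc_real = disc(real).view(-1)
        loss_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).view(-1)
        loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_real + loss_fake) / 2
        loss_disc.backward()
        optim_disc.step()

        # Generator step: score the fakes WITH THE DISCRIMINATOR,
        # then push that score towards the "real" label
        gen.zero_grad()
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        loss_gen.backward()
        optim_gen.step()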

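The discriminator log is equally telling: a BCE loss pinned at exactly 0.6931 is ln(2), the value BCELoss returns when the model answers 0.5 for every input, i.e. the discriminator has collapsed to a constant guess. A self-contained check, just to show where that number comes from:

import torch
import torch.nn as nn

p = torch.full((8,), 0.5)                    # a "discriminator" that always says 0.5
print(nn.BCELoss()(p, torch.ones_like(p)))   # tensor(0.6931) = ln 2
print(nn.BCELoss()(p, torch.zeros_like(p)))  # tensor(0.6931)

An Adam learning rate of 0.01 is far above what GAN training usually tolerates, and the weight decay on both optimizers pulls the weights towards zero; together these plausibly drive the discriminator into this constant regime. Dropping to 2e-4 without weight decay, as in the sketch above, is the conventional remedy.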

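A few secondary mismatches are worth fixing at the same time:

- The transform normalizes real images to [-1, 1], but the generator's final Sigmoid can only produce values in [0, 1], so the discriminator could separate real from fake by value range alone. Either end the generator with Tanh or drop the Normalize. A one-line sketch of the first option, replacing l6 in your Generator:

  self.l6 = nn.Tanh()  # outputs in [-1, 1], matching Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

- The DCGAN convention is LeakyReLU in the discriminator and ReLU in the generator; your code has them the other way round.
- Dropout on the discriminator's final one-channel score layer (l5) injects noise exactly where the score is produced; it is usually placed on earlier layers, if used at all.
- criterion_gan = nn.CrossEntropyLoss() is never used and can be removed.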