deep-learning - 为什么生成器损失为 0.00 但仍然在 GAN 中产生白色图像?
问题描述
我创建了一个 GAN(生成对抗网络)来创建 CIFAR-100 图像。该模型运行良好,但产生白色图像
我的代码如下(Colab 笔记本):
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
print(device)
tfs = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
train_data = datasets.CIFAR100('.', train = True, transform = tfs, download = True)
test_data = datasets.CIFAR100('.', train = False, transform = tfs, download = False)
idx = list(i for i in range(0, 10000))
train_data = torch.utils.data.Subset(train_data, idx)
batch_size = 100
train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers = 4)
test_loader = DataLoader(test_data, batch_size = batch_size, shuffle = True, num_workers = 4)
class conv_layer(nn.Module):
def __init__(self, in_channels, out_channels, ks, strides = 1, padding = 1,
normalize = True, dropout = 0.0):
super(conv_layer, self).__init__()
layer = nn.Conv2d(in_channels, out_channels, ks, strides, padding)
# Xavier He init
nn.init.xavier_normal_(layer.weight)
layers = [layer]
if normalize:
layers.append(nn.InstanceNorm2d(out_channels))
layers.append(nn.ReLU())
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class Discriminator(nn.Module):
def __init__(self, img_channels):
super(Discriminator, self).__init__()
self.l1 = conv_layer(img_channels, 64, 3)
self.l2 = conv_layer(64, 128, 3, 2, 0)
self.l3 = conv_layer(128, 256, 3, 2, 0)
self.l4 = conv_layer(256, 512, 3, 2, 0)
self.l5 = conv_layer(512, 1, 3, 2, 0, dropout = 0.5)
self.l6 = nn.Sigmoid()
def forward(self, x):
out = self.l1(x)
out = self.l2(out)
out = self.l3(out)
out = self.l4(out)
out = self.l5(out)
out = self.l6(out)
return out
class deconv_layer(nn.Module):
def __init__(self, in_channels, out_channels, ks, strides = 1, padding = 1,
normalize = True):
super(deconv_layer, self).__init__()
conv_layer = nn.ConvTranspose2d(in_channels, out_channels, ks, strides, padding)
# Xavier He init
nn.init.xavier_normal_(conv_layer.weight)
layers = [conv_layer]
if normalize:
layers.append(nn.InstanceNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class Generator(nn.Module):
def __init__(self, latent_dim, out_channel):
super(Generator, self).__init__()
self.l1 = nn.Linear(latent_dim, 256*4*4)
self.l2 = deconv_layer(256, 128, 4, 2)
self.l3 = deconv_layer(128, 128, 4, 2)
self.l4 = deconv_layer(128, 128, 4, 2)
self.l5 = nn.Conv2d(128, out_channel, 3, 1, 1)
self.l6 = nn.Sigmoid()
def forward(self, x):
out = self.l1(x)
out = out.reshape(-1, 256, 4, 4)
out = self.l2(out)
out = self.l3(out)
out = self.l4(out)
out = self.l5(out)
out = self.l6(out)
return out
image_shape = (3, 32, 32)
num_classes = 100
lr = 0.01
num_epochs = 500
latent_dim = 100
out_channel = image_shape[0]
gen = Generator(latent_dim, out_channel).to(device)
disc = Discriminator(out_channel).to(device)
optim_gen = torch.optim.Adam(gen.parameters(), lr = lr, weight_decay = 5e-4)
optim_disc = torch.optim.Adam(disc.parameters(), lr = lr, weight_decay = 3e-4)
criterion = nn.BCELoss()
criterion_gan = nn.CrossEntropyLoss()
import os
import shutil
if os.path.exists('log/fake'):
shutil.rmtree('log/fake')
if os.path.exists('log/real'):
shutil.rmtree('log/real')
writer_fake = SummaryWriter('log/fake')
writer_real = SummaryWriter('log/real')
step = 0
%reload_ext tensorboard
%tensorboard --logdir log
z_dim = 100
for epoch in range(num_epochs):
for id, (real, label) in enumerate(train_loader):
real = real.to(device)
batch_size = real.shape[0]
noise = torch.randn(batch_size, z_dim).to(device)
fake = gen(noise)
disc.zero_grad()
# output of discriminator(real)
disc_real = disc(real).view(-1)
loss_real = criterion(disc_real, torch.ones_like(disc_real))
# output of discriminator(fake)
disc_fake = disc(fake.detach()).view(-1)
loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
loss_disc = (loss_real + loss_fake) / 2
loss_disc.backward()
optim_disc.step()
# Train generator
gen.zero_grad()
gen_image = gen(noise)
loss_gen = criterion(gen_image, torch.ones_like(gen_image))
loss_gen.backward()
optim_gen.step()
if id == 0:
print(
f"Epoch [{epoch}/{num_epochs}] Batch {id}/{len(train_loader)} \
Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
)
with torch.no_grad():
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
fake = gen(fixed_noise)
data = real
img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
img_grid_real = torchvision.utils.make_grid(data, normalize=True)
writer_fake.add_image(
"Fake Images", img_grid_fake, global_step=step
)
writer_real.add_image(
"Real Images", img_grid_real, global_step=step
)
step += 1
我已经运行了几个时期,这是日志:
Epoch [0/500] Batch 0/100 Loss D: 0.6931, loss G: 0.7558
Epoch [1/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [2/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0010
Epoch [3/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0002
Epoch [4/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0004
Epoch [5/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [6/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0007
Epoch [7/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [8/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [9/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [10/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [11/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [12/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [13/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [14/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [15/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [16/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [17/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [18/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [19/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [20/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [21/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [22/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [23/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [24/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [25/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [26/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [27/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [28/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0002
Epoch [29/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [30/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [31/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [32/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [33/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [34/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0004
Epoch [35/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [36/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0004
Epoch [37/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [38/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [39/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [40/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [41/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [42/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [43/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0002
Epoch [44/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [45/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [46/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [47/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [48/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [49/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0002
Epoch [50/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [51/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0008
Epoch [52/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [53/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [54/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [55/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [56/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [57/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [58/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [59/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [60/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [61/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [62/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [63/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [64/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [65/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [66/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
Epoch [67/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [68/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [69/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [70/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [71/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [72/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0002
Epoch [73/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0000
Epoch [74/500] Batch 0/100 Loss D: 0.6931, loss G: 0.0001
如您所见,鉴别器和生成器的损失都停滞不前。产生的图像是白色空白图像。我无法理解这个结果的原因。请解释。谢谢你。
解决方案
推荐阅读
- android - 我从相机捕获图像它工作正常但是当我再次打开我的应用程序图像不保存
- progressive-web-apps - 无法使用可在移动设备上运行的经过测试的清单来触发 pwa(渐进式 Web 应用程序)的 Windows 桌面图标安装
- firebase - 使用 admob 连续显示视频广告
- python - 使用 pyserial 打开设备
- vue.js - 将模型类放在 Vuex 项目结构中的什么位置?
- python - 从作为字典的 Pandas 列中提取值
- visual-studio-code - 在 julia VSCode 中运行脚本
- python - 为什么我的变量返回无,即使被定义
- autodesk-forge - 如何在伪造查看器中获得模型的正确坐标
- c - 通过仅发送更改的值来减少 Firebase 延迟?