machine-learning - My DC-GAN on grayscale face images is not training well
Problem description
So I trained my Python/PyTorch DC-GAN (deep convolutional GAN) for 30 epochs on grayscale faces, and it pretty much failed. I added batch normalization and LeakyReLU activations to both the generator and the discriminator (I had heard these help a GAN converge), and used the Adam optimizer. The GAN is still only putting out random grayscale pixels, nothing even remotely resembling a face. The discriminator is not the problem; it works very well. Because the discriminator was doing much better than the generator, I then added a weight decay of 0.01 to the discriminator to help the GAN train, but to no avail. Finally, I tried training for more epochs (60). The GAN still generates only random pixels, and sometimes outputs a completely black image. The same training method worked on the MNIST dataset, although I used a much simpler GAN architecture there.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Each unpadded 3x3 convolution shrinks height and width by 2: 48 -> 38
        self.conv1 = nn.Conv2d(1, 4, 3)
        self.conv2 = nn.Conv2d(4, 8, 3)
        self.bnorm1 = nn.BatchNorm2d(8)
        self.conv3 = nn.Conv2d(8, 16, 3)
        self.conv4 = nn.Conv2d(16, 32, 3)
        self.bnorm2 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 4, 3)
        # 4 channels x 38 x 38 = 5776 features after flattening
        self.fc1 = nn.Linear(5776, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)

    def forward(self, x):
        pred = F.leaky_relu(self.conv1(x.reshape(-1, 1, 48, 48)))
        pred = F.leaky_relu(self.bnorm1(self.conv2(pred)))
        pred = F.leaky_relu(self.conv3(pred))
        pred = F.leaky_relu(self.bnorm2(self.conv4(pred)))
        pred = F.leaky_relu(self.conv5(pred))
        pred = pred.reshape(-1, 5776)
        pred = F.leaky_relu(self.fc1(pred))
        pred = F.leaky_relu(self.fc2(pred))
        pred = torch.sigmoid(self.fc3(pred))  # probability that x is a real image
        return pred


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # 512-dim noise vector -> 5776 values, reshaped to 4 x 38 x 38
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 5776)
        # Each 3x3 transposed convolution grows height and width by 2: 38 -> 48
        self.convT1 = nn.ConvTranspose2d(4, 32, 3)
        self.convT2 = nn.ConvTranspose2d(32, 16, 3)
        self.bnorm1 = nn.BatchNorm2d(16)
        self.convT3 = nn.ConvTranspose2d(16, 8, 3)
        self.convT4 = nn.ConvTranspose2d(8, 4, 3)
        self.bnorm2 = nn.BatchNorm2d(4)
        self.convT5 = nn.ConvTranspose2d(4, 1, 3)

    def forward(self, x):
        pred = F.leaky_relu(self.fc1(x))
        pred = F.leaky_relu(self.fc2(pred))
        pred = F.leaky_relu(self.fc3(pred))
        pred = pred.reshape(-1, 4, 38, 38)
        pred = F.leaky_relu(self.convT1(pred))
        pred = F.leaky_relu(self.bnorm1(self.convT2(pred)))
        pred = F.leaky_relu(self.convT3(pred))
        pred = F.leaky_relu(self.bnorm2(self.convT4(pred)))
        pred = torch.sigmoid(self.convT5(pred))  # pixel values in [0, 1]
        return pred
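The hard-coded 5776 in both networks is 4 channels x 38 x 38: each unpadded 3x3 convolution shrinks the image by 2 pixels, and each 3x3 transposed convolution grows it by 2. A small sketch of that arithmetic, using the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1 assumed; the helper names are hypothetical):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output height/width of nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Discriminator: five 3x3 convs (stride 1, no padding) applied to a 48x48 image
d_size = 48
for _ in range(5):
    d_size = conv_out(d_size, kernel=3)
print(d_size, 4 * d_size * d_size)  # 38 5776

# Generator: five 3x3 transposed convs applied to the 4 x 38 x 38 reshape of fc3's output
g_size = 38
for _ in range(5):
    g_size = conv_transpose_out(g_size, kernel=3)
print(g_size)  # 48
```

Since the shapes line up, the failure to learn is not a dimension mismatch. The same helpers can check other layer settings, e.g. `conv_out(48, 4, stride=2, padding=1)` gives 24.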
import torch.optim as optim

# discriminator, generator, d_optim, g_optim and tensor_dataset are created
# elsewhere (not shown in the question)
discriminator = discriminator.to("cuda")
generator = generator.to("cuda")
discriminator_losses = []
generator_losses = []
for epoch in range(30):
    for data, label in tensor_dataset:
        data = data.to("cuda")
        label = label.to("cuda")
        batch_size = data.size(0)
        real_labels = torch.ones(batch_size, 1).to("cuda")
        fake_labels = torch.zeros(batch_size, 1).to("cuda")

        # Discriminator step: detach the fakes so no generator gradients are computed
        noise = torch.randn(batch_size, 512).to("cuda")
        D_real = discriminator(data)
        D_fake = discriminator(generator(noise).detach())
        D_real_loss = F.binary_cross_entropy(D_real, real_labels)
        D_fake_loss = F.binary_cross_entropy(D_fake, fake_labels)
        D_loss = D_real_loss + D_fake_loss
        d_optim.zero_grad()
        D_loss.backward()
        d_optim.step()

        # Generator step: try to make the discriminator output "real" for fakes
        noise = torch.randn(batch_size, 512).to("cuda")
        D_fake = discriminator(generator(noise))
        G_loss = F.binary_cross_entropy(D_fake, real_labels)
        g_optim.zero_grad()
        G_loss.backward()
        g_optim.step()

        # .item() stores plain floats instead of tensors that keep each graph alive
        discriminator_losses.append(D_loss.item())
        generator_losses.append(G_loss.item())
    print(epoch)
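Two details in the loop above are worth handling explicitly: the fake batch should be detached during the discriminator update (otherwise `backward` also computes generator gradients that are then thrown away), and losses should be recorded with `.item()`. A minimal, CPU-only sketch of one training step follows. The tiny stand-in models and the latent size of 16 are hypothetical, and `nn.BCEWithLogitsLoss` on raw logits is swapped in for `sigmoid` + `F.binary_cross_entropy` because it is numerically more stable. The Adam settings `lr=2e-4, betas=(0.5, 0.999)` are the ones reported in the DCGAN paper (Radford et al.).

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy stand-ins so the sketch runs anywhere; substitute the
# question's models. The discriminator returns raw logits (no sigmoid),
# which is what BCEWithLogitsLoss expects.
LATENT = 16
generator = nn.Sequential(nn.Linear(LATENT, 48 * 48), nn.Sigmoid())
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(48 * 48, 1))

# Adam settings from the DCGAN paper
g_optim = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()


def train_step(real):
    batch_size = real.size(0)
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    # Discriminator update: detached fakes, so only D's gradients are computed
    fake = generator(torch.randn(batch_size, LATENT)).view(-1, 1, 48, 48)
    d_loss = (criterion(discriminator(real), real_labels)
              + criterion(discriminator(fake.detach()), fake_labels))
    d_optim.zero_grad()
    d_loss.backward()
    d_optim.step()

    # Generator update: reuse the same fakes, labelled "real"
    g_loss = criterion(discriminator(fake), real_labels)
    g_optim.zero_grad()
    g_loss.backward()
    g_optim.step()

    # Plain floats, so the loss history does not keep autograd graphs alive
    return d_loss.item(), g_loss.item()
```

In the question's loop, the body of `for data, label in tensor_dataset:` would call `train_step(data)` and append the two returned floats to `discriminator_losses` and `generator_losses`.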
Solution
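I am also new to deep learning and GAN models, but this approach solved a similar problem in my DCGAN project. Use a kernel size of at least 4×4: this is a guess on my part, but no matter how deep the network is, small kernels do not seem able to capture the details in the image. Most of the other tips I found come from here: https://machinelearningmastery.com/how-to-train-stable-generation-adversarial-networks/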
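A common way to apply the 4×4-kernel suggestion is the DCGAN-style layer with kernel 4, stride 2, padding 1, which exactly halves (discriminator) or doubles (generator) the spatial size per layer instead of the 2-pixel steps of the question's unpadded 3×3 layers. A minimal sketch for 48×48 images follows. The channel widths, the latent size of 100, and the 3×3 starting resolution are illustrative assumptions, not values from the answer; ReLU in the generator and LeakyReLU(0.2) in the discriminator follow the DCGAN paper, the sigmoid output is kept from the question, and the discriminator returns logits for use with `nn.BCEWithLogitsLoss`.

```python
import torch
import torch.nn as nn

LATENT = 100  # assumed latent size


class Generator(nn.Module):
    """Hypothetical 48x48 generator: 3 -> 6 -> 12 -> 24 -> 48 with 4x4 kernels."""

    def __init__(self, latent=LATENT):
        super().__init__()
        self.fc = nn.Linear(latent, 256 * 3 * 3)
        self.net = nn.Sequential(
            # kernel 4, stride 2, padding 1 exactly doubles height and width
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1), nn.Sigmoid(),
        )

    def forward(self, z):
        return self.net(self.fc(z).view(-1, 256, 3, 3))


class Discriminator(nn.Module):
    """Hypothetical 48x48 discriminator: 48 -> 24 -> 12 -> 6 -> 3, returns logits."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # kernel 4, stride 2, padding 1 exactly halves height and width
            nn.Conv2d(1, 32, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * 3 * 3, 1),
        )

    def forward(self, x):
        return self.net(x)
```

For example, `Generator()(torch.randn(8, LATENT))` produces a batch of shape `(8, 1, 48, 48)`, and passing that batch to `Discriminator()` gives logits of shape `(8, 1)`.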