python - 为什么判别器和生成器的损失没有变化?
问题描述
我正在尝试为 MNIST 数据集实现生成对抗网络 (GAN)。我为此使用 Pytorch。我的问题是,在一个时代之后,鉴别器和生成器的损失并没有改变。
我已经尝试了其他两种方法来构建网络,但它们会导致所有相同的问题:/
import os
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as grd
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision #Datasets
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.autograd import Variable
import pylab
#Parameter
batch_size = 64
epochs = 50000
image_size = 784
hidden_size = 392
sample_dir = 'samples'
save_dir = 'save'
noise_size = 100
lr = 0.001
# Image processing
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,),(0.5,))])
# Discriminator
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1),
nn.Sigmoid()
)
# Generator
G = nn.Sequential(
nn.Linear(noise_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Sigmoid()
)
# Lossfunction and optimizer (sigmoid cross entropy with logits and Adam)
criterion = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr = lr)
g_optimizer = torch.optim.Adam(G.parameters(), lr = lr)
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
# Statistics to be saved
d_losses = np.zeros(epochs)
g_losses = np.zeros(epochs)
real_scores = np.zeros(epochs)
fake_scores = np.zeros(epochs)
# Start training
total_step = len(data_loader)
for epoch in range(epochs):
for i, (images, _) in enumerate(data_loader):
if images.shape[0] != 64:
continue
images = images.view(batch_size, -1).cuda()
images = Variable(images)
# Create the labels which are later used as input for the BCE loss
real_labels = torch.ones(batch_size, 1).cuda()
real_labels = Variable(real_labels)
fake_labels = torch.zeros(batch_size, 1).cuda()
fake_labels = Variable(fake_labels)
# Train discriminator
# Compute BCE_WithLogitsLoss using real images
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# Compute BCE_WithLogitsLoss using fake images
# First term of the loss is always zero since fake_labels == 0
z = torch.randn(batch_size, noise_size).cuda()
z = Variable(z)
fake_images = G(z)
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# Backprop and optimize
# If D is trained so well, then don't update
d_loss = d_loss_real + d_loss_fake
reset_grad()
d_loss.backward()
d_optimizer.step()
# Train generator
# Compute loss with fake images
z = torch.randn(batch_size, noise_size).cuda()
z = Variable(z)
fake_images = G(z)
outputs = D(fake_images)
# We train G to maximize log(D(G(z)) instead of minimizing log(1 -D(G(z)))
# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
g_loss = criterion(outputs, real_labels)
# Backprop and optimize
# if G is trained so well, then don't update
reset_grad()
g_loss.backward()
g_optimizer.step()
# Update statistics
d_losses[epoch] = d_losses[epoch]*(i/(i+1.)) + d_loss.item()*(1./(i+1.))
g_losses[epoch] = g_losses[epoch]*(i/(i+1.)) + g_loss.item()*(1./(i+1.))
real_scores[epoch] = real_scores[epoch]*(i/(i+1.)) + real_score.mean().item()*(1./(i+1.))
fake_scores[epoch] = fake_scores[epoch]*(i/(i+1.)) + fake_score.mean().item()*(1./(i+1.))
# print results
print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, epochs, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
生成器和判别器的损失应该随着时代的变化而变化,但它们不会。
Epoch [0/50000], d_loss: 1.0069, g_loss: 0.6927, D(x): 1.00, D(G(z)): 0.00
Epoch [1/50000], d_loss: 1.0065, g_loss: 0.6931, D(x): 1.00, D(G(z)): 0.00
Epoch [2/50000], d_loss: 1.0064, g_loss: 0.6931, D(x): 1.00, D(G(z)): 0.00
Epoch [3/50000], d_loss: 1.0064, g_loss: 0.6931, D(x): 1.00, D(G(z)): 0.00
Epoch [4/50000], d_loss: 1.0064, g_loss: 0.6931, D(x): 1.00, D(G(z)): 0.00
Epoch [5/50000], d_loss: 1.0064, g_loss: 0.6931, D(x): 1.00, D(G(z)): 0.00
谢谢你的帮助。
解决方案
我找到了问题的解决方案。BCEWithLogitsLoss() 和 Sigmoid() 不能一起工作,因为 BCEWithLogitsLoss() 包括 Sigmoid 激活。所以你可以在没有 Sigmoid() 的情况下使用 BCEWithLogitsLoss() 或者你可以使用 Sigmoid() 和 BCELoss()
推荐阅读
- python-3.x - 特殊字符仅作为字符串的一部分打印,但不独立打印(python3)
- azure - .NetCore 2.2 API 在使用用户分配的身份时无法从 AAD 获取令牌
- android - 有没有办法用片段着色器在顶点着色器的一个点的位置画一个圆?
- asp.net-mvc - 如何在运行时从 MVC 应用程序将参数传递给报表服务器上的 Power BI 报表
- jest-fetch-mock - 如何使用 jest-fetch-mock 返回一个数组?
- javascript - 无法从 Material UI 中的卡片内容中删除 padding-bottom
- if-statement - 如何在嵌套的 ifelse 语句中做到这一点
- python - 无法从窗口中删除 Opencv 绘图
- audit.net - 如何在审计中包含特定列(用户 ID)
- c# - SQL Server CLR Int64 到 SQLInt64 指定的强制转换无效