首页 > 解决方案 > WGAN-GP 生成的图像看起来很灰

问题描述

此问题已解决:原因在于图像分辨率过高。


各位好,下面是我的 WGAN-GP 生成的图片,颜色看起来发灰、不够鲜艳(不知道该怎么形容更准确)。另外,损失一直在增加……

结果

训练

数据预处理部分如下:

def data_preprocess(data_dir="./256", batch_size=64, image_size=256, num_workers=0):
    """Build a shuffled DataLoader over an ``ImageFolder`` dataset.

    Each image is resized to ``image_size`` x ``image_size`` (aspect ratio is
    not preserved), converted to a tensor in [0, 1] by ``ToTensor`` and then
    normalized per channel with mean 0.5 / std 0.5, i.e. x' = (x - 0.5) / 0.5,
    which maps pixel values to [-1, 1].

    Args:
        data_dir: root directory handed to ``ImageFolder``.
        batch_size: number of samples per batch.
        image_size: side length of the square resized images.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Returns:
        ``(data_loader, dataset, batch_size)``.
    """
    data_transforms = transforms.Compose([
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    data = ImageFolder(data_dir, transform=data_transforms)
    data_loader = Data.DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)

    return data_loader, data, batch_size

生成器 G 和判别器 D 的定义如下:

class Generator(torch.nn.Module):
    """DCGAN-style generator: a (N, 100, 1, 1) latent tensor -> (N, 3, 256, 256) image.

    The first transposed convolution expands 1x1 to 4x4; each following one
    doubles the resolution (4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256). A final
    Tanh bounds the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        self.main = nn.Sequential(  # Z: 100
            nn.ConvTranspose2d(100, 1024, 4, 2, 0),  # -> (1024, 4, 4)
            nn.BatchNorm2d(num_features=1024),
            nn.ReLU(True),

            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # -> (512, 8, 8)
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # -> (256, 16, 16)
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # -> (128, 32, 32)
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # -> (64, 64, 64)
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # -> (32, 128, 128)
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(True),

            nn.ConvTranspose2d(32, 3, 4, 2, 1))  # Output of Main: (3, 256, 256)

        self.output = nn.Tanh()

    def weight_init(self, type):
        """Re-initialize the network's weights in place.

        Args:
            type: initializer for every (transposed) convolution weight:
                ``"default"`` -> N(0, 0.02) as in DCGAN,
                ``"kaiming"`` -> He normal for ReLU, ``"xavier"`` -> Glorot normal.
                BatchNorm layers always get scale ~ N(1, 0.02) and bias 0.

        Raises:
            ValueError: if ``type`` is not one of the names above.
        """
        conv_inits = {
            "default": lambda w: nn.init.normal_(w, 0, 0.02),
            # NOTE: for ConvTranspose2d weights, PyTorch's fan computation reads
            # dim 1 (out_channels) as "fan_in".
            "kaiming": lambda w: nn.init.kaiming_normal_(
                w, a=0, mode='fan_in', nonlinearity='relu'),
            "xavier": lambda w: nn.init.xavier_normal_(w, gain=1.0),
        }
        if type not in conv_inits:
            raise ValueError(
                f"unknown init type {type!r}; expected one of {sorted(conv_inits)}")
        init_conv = conv_inits[type]

        for m in self.modules():
            # ConvTranspose2d is NOT a subclass of Conv2d: testing only for
            # nn.Conv2d skipped every layer of this generator.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                init_conv(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                # Scale centred on 1, not 0: a near-zero gamma shrinks the input
                # of the last layer, so the initial Tanh output sits near 0 (grey).
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)

    def forward(self, x):
        x = self.main(x)
        return self.output(x)

class Discriminator(torch.nn.Module):
    """WGAN-GP critic: a (N, 3, 256, 256) image -> raw (N, 1, 1, 1) score.

    Six stride-2 convolutions halve the resolution 256 -> 4, and a final 4x4
    convolution reduces it to one value. There is no Sigmoid because a
    Wasserstein critic outputs unbounded scores. LayerNorm is used instead of
    BatchNorm; the WGAN-GP paper advises against batch norm in the critic
    because the gradient penalty is defined per sample.

    NOTE(review): ``nn.LayerNorm((H, W))`` normalizes over the last two dims
    only, i.e. each channel map separately, not over (C, H, W) as in the
    paper's layer norm -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        self.main = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1),  # (3,256,256) -> (32,128,128)
            nn.LayerNorm((128, 128)),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),  # -> (64,64,64)
            nn.LayerNorm((64, 64)),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),  # -> (128,32,32)
            nn.LayerNorm((32, 32)),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),  # -> (256,16,16)
            nn.LayerNorm((16, 16)),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),  # -> (512,8,8)
            nn.LayerNorm((8, 8)),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),  # -> (1024,4,4)
            nn.LayerNorm((4, 4)),
            nn.LeakyReLU(0.2, inplace=True))

        self.output = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=2, padding=0)  # -> (1,1,1)
            )

    def weight_init(self, type):
        """Re-initialize the network's weights in place.

        Args:
            type: initializer for every convolution weight, including the
                output convolution: ``"default"`` -> N(0, 0.02),
                ``"kaiming"`` -> He normal for LeakyReLU(0.2),
                ``"xavier"`` -> Glorot normal.
                LayerNorm layers always get scale ~ N(1, 0.02) and bias 0.

        Raises:
            ValueError: if ``type`` is not one of the names above.
        """
        conv_inits = {
            "default": lambda w: nn.init.normal_(w, 0, 0.02),
            # a=0.2 matches the negative slope of the LeakyReLU layers above.
            "kaiming": lambda w: nn.init.kaiming_normal_(
                w, a=0.2, mode='fan_in', nonlinearity='leaky_relu'),
            "xavier": lambda w: nn.init.xavier_normal_(w, gain=1.0),
        }
        if type not in conv_inits:
            raise ValueError(
                f"unknown init type {type!r}; expected one of {sorted(conv_inits)}")
        init_conv = conv_inits[type]

        # self.modules() also reaches the output convolution, which iterating
        # over self.main alone left at PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_conv(m.weight.data)
            elif isinstance(m, nn.LayerNorm):
                # Scale centred on 1: a ~0 scale would zero out every activation.
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)

    def forward(self, x):
        x = self.main(x)
        return self.output(x)

梯度惩罚(GP)函数:

def calculate_gradient_penalty(real_images, fake_images, critic=None, lambda_gp=None):
    """Return the WGAN-GP gradient penalty for one batch.

    Draws one random point on the segment between each (real, fake) pair,
    evaluates the critic there and penalizes how far the L2 norm of each
    sample's input gradient is from 1:

        lambda_gp * mean_i (||grad_x critic(x_hat_i)||_2 - 1) ** 2

    Args:
        real_images: batch of real samples; dim 0 is the batch dimension.
        fake_images: generated samples, same shape as ``real_images``. Pass
            detached tensors (the training loop passes ``.data``).
        critic: module to evaluate; defaults to the module-level ``D``.
        lambda_gp: penalty weight; defaults to the module-level ``lambda_term``.

    Returns:
        0-dim tensor. Because ``create_graph=True``, calling ``.backward()``
        on it reaches the critic's parameters.
    """
    if critic is None:
        critic = D
    if lambda_gp is None:
        lambda_gp = lambda_term

    batch_size = real_images.size(0)
    # One mixing coefficient per sample, broadcast over all remaining dims.
    t_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
    t = torch.rand(t_shape, device=real_images.device, dtype=real_images.dtype)

    interpolates = t * real_images + (1 - t) * fake_images
    interpolates.requires_grad_(True)

    disc_interpolates = critic(interpolates)

    grad = torch.autograd.grad(outputs=disc_interpolates,
                               inputs=interpolates,
                               grad_outputs=torch.ones_like(disc_interpolates),
                               create_graph=True,
                               retain_graph=True)[0]

    # The constraint is on the gradient w.r.t. the WHOLE input sample, so
    # flatten every non-batch dim before taking the norm. ``grad.norm(2, dim=1)``
    # only reduced the channel axis, yielding one norm per pixel instead.
    grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean() * lambda_gp

最后是训练部分:

# ---- Models ----
G = Generator()
D = Discriminator()
G.weight_init("kaiming")  # "default", "kaiming", "xavier"
D.weight_init("kaiming")
G.to(device)
D.to(device)

# ---- Hyper-parameters ----
learning_rate = 1e-4
lambda_term = 10         # gradient-penalty weight, read by calculate_gradient_penalty
generator_iters = 10000
critic_iter = 5          # critic updates per generator update
record_lenth = 0         # global step counter for the critic TensorBoard curves

d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate, betas=(0.5, 0.999))
data_loader, data, batch_size = data_preprocess()

# Loss histories: one entry per critic step / generator step.
d_progress = []
d_fake_progress = []
d_real_progress = []
penalty = []
g_progress = []

data = get_infinite_batches(data_loader)
# Gradient seeds: backward(mone) ascends the score, backward(one) descends it.
one = torch.FloatTensor([1]).to(device)
mone = (one * -1).to(device)

for g_iter in range(generator_iters):

    print('----------G Iter-{}----------'.format(g_iter+1))

    # ---- Critic updates ----
    for p in D.parameters():
        p.requires_grad = True  # re-enable after the generator step froze them

    d_loss_real = 0
    d_loss_fake = 0
    Wasserstein_D = 0

    for d_iter in range(critic_iter):
        D.zero_grad()

        images = next(data)
        # Skip a short batch so every critic step sees exactly batch_size samples.
        if images.size()[0] != batch_size:
            continue
        images = images.to(device)

        # Real images: push D(x) up.
        d_real = D(images)
        d_loss_real = d_real.mean(0).view(1)
        d_loss_real.backward(mone)

        # Fake images: push D(G(z)) down. G is not updated in this step, so
        # generate without autograd instead of back-propagating through G.
        z = torch.randn(batch_size, 100, 1, 1).to(device)
        with torch.no_grad():
            fake_images = G(z)
        d_fake = D(fake_images)
        d_loss_fake = d_fake.mean(0).view(1)
        d_loss_fake.backward(one)

        # Gradient penalty on real/fake interpolates.
        gradient_penalty = calculate_gradient_penalty(images.data, fake_images.data)
        gradient_penalty.backward()

        # Totals for logging only; gradients were accumulated by the backward calls above.
        d_loss = d_loss_fake - d_loss_real + gradient_penalty
        Wasserstein_D = d_loss_real - d_loss_fake
        d_optimizer.step()
        print('D Loss: %.6s, Fake: %.6s, Real: %.6s, Penalty: %.6s' %(d_loss.item(),d_loss_fake.item(),d_loss_real.item(),gradient_penalty.item()))

        d_progress.append(d_loss.item())  # Store Loss
        d_fake_progress.append(d_loss_fake.item())
        d_real_progress.append(d_loss_real.item())
        penalty.append(gradient_penalty.item())

        # Step numbering starts at 1, like the generator curve below.
        record_lenth += 1
        writer.add_scalars('Continue Test D Loss', {'D Loss': d_loss,
                                 'D Loss Fake': d_loss_fake,
                                 'D Loss Real': d_loss_real}, record_lenth)

    # ---- Generator update ----
    for p in D.parameters():
        p.requires_grad = False  # D is not updated here; skip its weight gradients

    G.zero_grad()
    z = torch.randn(batch_size, 100, 1, 1).to(device)
    fake_images = G(z)
    d_fake = D(fake_images)
    g_loss = d_fake.mean(0).view(1)
    g_loss.backward(mone)  # ascend D(G(z)), i.e. minimize -D(G(z))
    g_cost = -g_loss
    g_optimizer.step()
    print('G Loss: %.6s'% g_loss.item())

    g_progress.append(g_loss.item())  # Store Loss
    writer.add_scalar('Continue Test G Loss', g_loss, g_iter+1)  # Tensor Board

我不知道为什么会出现这种情况。我尝试过训练更多迭代,但训练到 10000 次迭代也没有改善……我会继续尝试解决这个问题,但我是 GAN 新手,如果有人有任何建议,请告诉我,先谢谢了!祝大家有美好的一天!

标签: python deep-learning pytorch computer-vision generative-adversarial-network

解决方案


您正在应用以下转换:

transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])

由于前一个变换是 transforms.ToTensor,它输出的值位于 [0, 1] 范围内。而您在 transforms.Normalize 中使用的均值为 0.5、标准差为 0.5,即 x' = (x − 0.5) / 0.5,因此输入图像最终会被映射到 [-1, 1] 范围内,这可能解释了图像呈现灰色的原因。根据您的数据集以及是否使用批归一化(BatchNorm),可以考虑将这些值改为数据集实际的统计量(各通道的均值和标准差)。


推荐阅读