首页 > 解决方案 > 如何为 StyleGAN2 插值编写优化器?

问题描述

我想使用 NVLabs 的 StyleGAN2-ADA-PyTorch 插入两个图像。为简单起见,可以说,对于不同人的两个图像,我想创建第三个图像来描绘第三个人,第一个图像中的身体,第二个图像中的头部。我还准备好手头的两个图像的相应 w 向量。

# G is a generative model in line with StyleGAN2, trained to output 512x512 images.
# Latents shape is [1, 16, 512]
G = G.eval().requires_grad_(False).to(device) # type: ignore
num_ws = G.mapping.num_ws # 16
w_dim = G.mapping.w_dim # 512

# Segmentation network is used to extract important parts from images
segmentation_dnn = segmentation_dnn.to(device)

# Source images are represented as latent vectors. I use G to generate actual images:
image_body = image_from_output(G.synthesis(w_body, noise_mode='const'))
image_head = image_from_output(G.synthesis(w_head, noise_mode='const'))

# Custom function is applied to source images, creating masked images. 
# In masked images, only head or body is present (and the rest is filled with white pixels)
image_body_masked = apply_segmentation_mask(image_body, segmentation_dnn, select='body')
image_head_masked = apply_segmentation_mask(image_head, segmentation_dnn, select='head')

为了比较任意两张图片的相似度,我使用 VGGLos

# VGG16 is used as a feature extractor to evaluate image similarity
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
with dnnlib.util.open_url(url) as f:
    vgg16 = torch.jit.load(f).eval().to(device)

class VGGLoss(nn.Module):
    def __init__(self, device, vgg):
        super().__init__()
        
        for param in self.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss().to(device)
        
    def forward(self, source, target):
        loss = 0 
        source_features = self.vgg(source, resize_images=False, return_lpips=True)
        target_features = self.vgg(target, resize_images=False, return_lpips=True)
        loss += self.criterion(source, target)
            
        return loss 

vgg_loss = VGGLoss(device, vgg=vgg16)

现在,我想插入image_bodyimage_head创建image_target. 为此,我需要在 StyleGAN2 的潜在空间中找到 image_target 的潜在表示。粗略地说,我们可以使用优化系数query_opt来部分包含来自image_body和的潜在image_headw_target = w_body + (query_opt * (w_head - w_person))

query_opt = torch.randn([1, num_ws, 1], dtype=torch.float32, device=device, requires_grad=True)
optimizer = torch.optim.Adam(query_opt, betas=(0.9, 0.999), lr=initial_learning_rate)

w_out = []
for step in num_steps:
    # Learning rate schedule.
    t = step / num_steps
    lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
    lr = initial_learning_rate * lr_ramp
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Synth image from w_target using query_opt.
    # This interpolation formula is an important step, and I think my math might be out of order up here

    w_target = w_body + (query_opt * (w_head - w_person))
    image_target = image_from_output(G.synthesis(ws, noise_mode='const'))
    image_target_body_masked = apply_segmentation_mask(image_target, segmentation_dnn, select='body')
    image_target_head_masked = apply_segmentation_mask(image_target, segmentation_dnn, select='head')
    loss = vgg_loss(image_body_masked, image_target_body_masked) + vgg_loss(image_head_masked, image_target_head_masked)
    
    # Step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    logprint(f'step {step+1:>4d}/{num_steps}: loss {float(loss):<5.2f}')

    # Save current w_target
    w_out[step] = w_target.detach()

我不知道如何使我的优化器以实际优化组合 VGGloss 的方式实际定位 query_opt。我必须在我的 PyTorch 代码中遗漏一些东西,或者甚至在主插值公式中。

标签: machine-learningoptimizationpytorchlinear-algebragenerative-adversarial-network

解决方案


推荐阅读