Updating the weights of only part of a model (nn.Module)

Problem description

I ran into a problem while building a network loosely based on the CycleGAN architecture.

I put all of its components inside a single nn.Module:

from torch import nn

from classes.EncoderDecoder import EncoderDecoder
from classes.Discriminator import Discriminator

class CycleGAN(nn.Module):
    def __init__(self):
        super(CycleGAN, self).__init__()
        self.encdec1 = EncoderDecoder(encoder_in_channels=3)
        self.encdec2 = EncoderDecoder(encoder_in_channels=3)
        self.disc = Discriminator()
        

    def forward(self, images, images_bw):

        disc_color = self.disc(images) # I want the Discriminator to be trained here
        disc_bw = self.disc(images_bw) # I want the Discriminator to be trained here

        decoded1 = self.encdec1(images_bw) # EncoderDecoder forward pass
        decoded2 = self.encdec2(decoded1)

        decoded_disc = self.disc(decoded1)  # I don't want to train the Discriminator here, 
                                            # only the EncoderDecoder should be trained based
                                            # on this Discriminator's result

        return [disc_color, disc_bw, decoded1, decoded2, decoded_disc]

This is how I initialize the module, the loss functions, and the optimizer:

from torch import float32
from torch.nn import BCELoss, MSELoss
from torch.optim import Adam

c_gan = CycleGAN().to('cuda', dtype=float32, non_blocking=True)

l2_loss = MSELoss().to('cuda', dtype=float32).train()
bce_loss = BCELoss().to('cuda', dtype=float32).train()

optimizer_gan = Adam(c_gan.parameters(), lr=0.00001)
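Because the optimizer is built from `c_gan.parameters()`, it holds every parameter of every registered submodule, the discriminator's included, so `optimizer_gan.step()` updates the discriminator with whatever gradient has accumulated in it. A quick check, using small `nn.Linear` layers as hypothetical stand-ins for the question's `EncoderDecoder` and `Discriminator` (their real definitions are not shown):

```python
from torch import nn
from torch.optim import Adam


class TinyCycleGAN(nn.Module):
    """Hypothetical stand-in: nn.Linear layers replace EncoderDecoder/Discriminator."""

    def __init__(self):
        super().__init__()
        self.encdec1 = nn.Linear(4, 4)
        self.encdec2 = nn.Linear(4, 4)
        self.disc = nn.Linear(4, 1)


model = TinyCycleGAN()
optimizer = Adam(model.parameters(), lr=1e-5)

# Identities of every parameter the optimizer will step.
optimized = {id(p) for group in optimizer.param_groups for p in group["params"]}
disc_params = {id(p) for p in model.disc.parameters()}

# Every discriminator parameter is under this single optimizer's control.
print(disc_params <= optimized)  # True
```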

This is how I train the network inside the training loop:

c_gan.zero_grad()
optimizer_gan.zero_grad()

disc_color, disc_bw, decoded1, decoded2, decoded_disc = c_gan(images, images_bw)

loss_true = bce_loss(disc_color, label_true)
loss_false = bce_loss(disc_bw, label_false)
disc_loss = loss_true + loss_false
disc_loss.backward()

decoded_loss = l2_loss(decoded2, images_bw)
decoded_disc_loss = bce_loss(decoded_disc, label_true) # This is where the loss for that Discriminator forward pass is calculated
both_decoded_losses = decoded_loss + decoded_disc_loss
both_decoded_losses.backward()
optimizer_gan.step()
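A minimal CPU reproduction of the issue, again with hypothetical `nn.Linear` stand-ins and random tensors in place of the real images and labels. It runs only the generator-side losses from the loop and checks whether the discriminator still receives a gradient from them:

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyCycleGAN(nn.Module):
    """Hypothetical stand-in: nn.Linear layers replace EncoderDecoder/Discriminator."""

    def __init__(self):
        super().__init__()
        self.encdec1 = nn.Linear(4, 4)
        self.encdec2 = nn.Linear(4, 4)
        # Sigmoid so the output is a probability, as BCELoss expects.
        self.disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

    def forward(self, images, images_bw):
        disc_color = self.disc(images)
        disc_bw = self.disc(images_bw)
        decoded1 = self.encdec1(images_bw)
        decoded2 = self.encdec2(decoded1)
        decoded_disc = self.disc(decoded1)
        return disc_color, disc_bw, decoded1, decoded2, decoded_disc


torch.manual_seed(0)
model = TinyCycleGAN()
images, images_bw = torch.rand(8, 4), torch.rand(8, 4)
label_true = torch.ones(8, 1)

_, _, _, decoded2, decoded_disc = model(images, images_bw)

# Only the generator-side losses from the question's loop.
decoded_loss = F.mse_loss(decoded2, images_bw)
decoded_disc_loss = F.binary_cross_entropy(decoded_disc, label_true)
(decoded_loss + decoded_disc_loss).backward()

# The discriminator received a gradient from these losses, so a single
# optimizer step would also update it from the EncoderDecoder pass.
disc_bias_grad = model.disc[0].bias.grad
print(disc_bias_grad is not None and disc_bias_grad.abs().sum().item() > 0)  # True
```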

The problem

I don't want the Discriminator module to be trained on the EncoderDecoder -> Discriminator forward pass. However, I do want it to be trained on the images -> Discriminator and images_bw -> Discriminator forward passes.

I'd appreciate any help.

Tags: python, machine-learning, pytorch

Solution


From "PyTorch Example: Freezing a part of the net (including fine-tuning)" - GitHub gist

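from torch import nn

class CycleGAN(nn.Module):
    ...  # as defined in the question

c_gan = CycleGAN()

# freeze every layer of the discriminator, i.e. for each layer:
# c_gan.disc.{layer}.weight.requires_grad = False
# c_gan.disc.{layer}.bias.requires_grad = False
for param in c_gan.disc.parameters():
    param.requires_grad = False

...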
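Freezing the discriminator for the whole run would also stop it from learning on the `images` and `images_bw` passes, which the question still wants. Autograd decides whether to record a gradient edge to a parameter at the moment the operation runs, so one way to get the requested behaviour is to freeze the discriminator only around the `decoded1` pass inside `forward` and unfreeze it right after. A minimal sketch with hypothetical `nn.Linear` stand-ins and a hypothetical helper `set_requires_grad`; gradients still flow through the frozen discriminator into the encoder-decoder, they just don't accumulate in the discriminator's weights:

```python
import torch
import torch.nn.functional as F
from torch import nn


def set_requires_grad(module, flag):
    """Hypothetical helper: toggle gradient tracking for all parameters of `module`."""
    for param in module.parameters():
        param.requires_grad_(flag)


class TinyCycleGAN(nn.Module):
    """Hypothetical stand-in: nn.Linear layers replace EncoderDecoder/Discriminator."""

    def __init__(self):
        super().__init__()
        self.encdec1 = nn.Linear(4, 4)
        self.encdec2 = nn.Linear(4, 4)
        self.disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

    def forward(self, images, images_bw):
        disc_color = self.disc(images)    # discriminator is trained here
        disc_bw = self.disc(images_bw)    # ...and here
        decoded1 = self.encdec1(images_bw)
        decoded2 = self.encdec2(decoded1)

        # Frozen only for this pass: no gradient edges into the discriminator's
        # weights are recorded, but the path back to encdec1 is kept.
        set_requires_grad(self.disc, False)
        decoded_disc = self.disc(decoded1)
        set_requires_grad(self.disc, True)

        return disc_color, disc_bw, decoded1, decoded2, decoded_disc


torch.manual_seed(0)
model = TinyCycleGAN()
images, images_bw = torch.rand(8, 4), torch.rand(8, 4)
label_true, label_false = torch.ones(8, 1), torch.zeros(8, 1)

disc_color, disc_bw, decoded1, decoded2, decoded_disc = model(images, images_bw)

# Same order as the question's loop: discriminator loss first.
disc_loss = F.binary_cross_entropy(disc_color, label_true) + F.binary_cross_entropy(disc_bw, label_false)
disc_loss.backward()
disc_grad_before = model.disc[0].weight.grad.clone()

gen_loss = F.mse_loss(decoded2, images_bw) + F.binary_cross_entropy(decoded_disc, label_true)
gen_loss.backward()

# The generator-side backward left the discriminator's gradient untouched...
print(torch.equal(model.disc[0].weight.grad, disc_grad_before))  # True
# ...while decoded_disc still carries a gradient path back to the encoder-decoder.
print(decoded_disc.requires_grad)  # True
```

With this change the single `optimizer_gan` from the question can stay as it is, since the discriminator's gradient now comes only from `disc_loss`.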
