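python - Updating the weights of part of a model (nn.Module)
Problem description
I ran into a problem while building a network loosely based on the CycleGAN architecture.
I put all of its components inside a single nn.Module: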
from torch import nn
from classes.EncoderDecoder import EncoderDecoder
from classes.Discriminator import Discriminator

class CycleGAN(nn.Module):
    def __init__(self):
        super(CycleGAN, self).__init__()
        self.encdec1 = EncoderDecoder(encoder_in_channels=3)
        self.encdec2 = EncoderDecoder(encoder_in_channels=3)
        self.disc = Discriminator()

    def forward(self, images, images_bw):
        disc_color = self.disc(images)      # I want the Discriminator to be trained here
        disc_bw = self.disc(images_bw)      # I want the Discriminator to be trained here
        decoded1 = self.encdec1(images_bw)  # EncoderDecoder forward pass
        decoded2 = self.encdec2(decoded1)
        decoded_disc = self.disc(decoded1)  # I don't want to train the Discriminator here,
                                            # only the EncoderDecoder should be trained based
                                            # on this Discriminator's result
        return [disc_color, disc_bw, decoded1, decoded2, decoded_disc]
This is how I initialize the module, the loss functions, and the optimizer:
from torch import float32
from torch.nn import BCELoss, MSELoss
from torch.optim import Adam

c_gan = CycleGAN().to('cuda', dtype=float32, non_blocking=True)
l2_loss = MSELoss().to('cuda', dtype=float32).train()
bce_loss = BCELoss().to('cuda', dtype=float32).train()
optimizer_gan = Adam(c_gan.parameters(), lr=0.00001)
This is how I train the network inside the training loop:
c_gan.zero_grad()
optimizer_gan.zero_grad()
disc_color, disc_bw, decoded1, decoded2, decoded_disc = c_gan(images, images_bw)
loss_true = bce_loss(disc_color, label_true)
loss_false = bce_loss(disc_bw, label_false)
disc_loss = loss_true + loss_false
disc_loss.backward()
decoded_loss = l2_loss(decoded2, images_bw)
decoded_disc_loss = bce_loss(decoded_disc, label_true) # This is where the loss for that Discriminator forward pass is calculated
both_decoded_losses = decoded_loss + decoded_disc_loss
both_decoded_losses.backward()
optimizer_gan.step()
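Question
I don't want the Discriminator to be trained on the EncoderDecoder -> Discriminator forward pass. However, I do want it to be trained on the images -> Discriminator and images_bw -> Discriminator forward passes.
Is it possible to achieve this with only one optimizer for my CycleGAN module? Can I freeze the Discriminator during the optimizer's .step()?
I would appreciate any help.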
Solution
From "PyTorch example: freezing a part of the net (including fine-tuning)" - GitHub gist
class CycleGan:
    ...

c_gan = CycleGan()
# freeze every layer of discriminator
# c_gan.disc.{layer}.weight.requires_grad = False
# c_gan.disc.{layer}.bias.requires_grad = False
...
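Note that setting requires_grad = False once, before training, would stop every Discriminator update, including the ones on images and images_bw that the question wants to keep. One way to apply the same idea selectively is to switch the flag off only around the EncoderDecoder -> Discriminator call and switch it back on afterwards. The following is a minimal sketch of that approach, not the gist's code: TinyCycleGAN, _set_disc_trainable, and train_step are hypothetical names, and plain nn.Linear layers stand in for the question's EncoderDecoder and Discriminator, whose definitions are not shown.

```python
import torch
from torch import nn


class TinyCycleGAN(nn.Module):
    """Hypothetical stand-in: Linear layers replace EncoderDecoder/Discriminator."""

    def __init__(self):
        super().__init__()
        self.encdec1 = nn.Linear(4, 4)
        self.encdec2 = nn.Linear(4, 4)
        self.disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

    def _set_disc_trainable(self, flag):
        # Toggle gradient tracking for every Discriminator parameter.
        for p in self.disc.parameters():
            p.requires_grad_(flag)

    def forward(self, images, images_bw):
        # These two passes are recorded with trainable Discriminator weights.
        disc_color = self.disc(images)
        disc_bw = self.disc(images_bw)

        decoded1 = self.encdec1(images_bw)
        decoded2 = self.encdec2(decoded1)

        # Freeze only for this pass: gradients still flow back through the
        # Discriminator to decoded1, but not into the Discriminator weights.
        self._set_disc_trainable(False)
        decoded_disc = self.disc(decoded1)
        self._set_disc_trainable(True)

        return disc_color, disc_bw, decoded1, decoded2, decoded_disc


def train_step(model, optimizer, images, images_bw):
    """One update with a single optimizer over all parameters."""
    bce, mse = nn.BCELoss(), nn.MSELoss()
    n = images.size(0)
    label_true = torch.ones(n, 1, device=images.device)
    label_false = torch.zeros(n, 1, device=images.device)

    optimizer.zero_grad()
    disc_color, disc_bw, _, decoded2, decoded_disc = model(images, images_bw)

    # Discriminator loss: real colour images vs. black-and-white images.
    disc_loss = bce(disc_color, label_true) + bce(disc_bw, label_false)
    # EncoderDecoder loss: reconstruction plus adversarial term.
    gen_loss = mse(decoded2, images_bw) + bce(decoded_disc, label_true)

    (disc_loss + gen_loss).backward()
    optimizer.step()
    return disc_loss.item(), gen_loss.item()
```

Because autograd decides which tensors receive gradients at forward time, the third Discriminator call contributes nothing to the Discriminator's .grad, while the first two calls still do. A common alternative, if a single optimizer is not a hard requirement, is to use separate optimizers for the Discriminator and the EncoderDecoders.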