首页 > 解决方案 > 训练pytorch模型后如何丢弃分支

问题描述

我正在尝试实现一个 FCN,pytorch其整体结构如下:

在此处输入图像描述

到目前为止的代码如下所示:

class SNet(nn.Module):
    def __init__(self):
        super(SNet, self).__init__()
        
        self.enc_a = encoder(...)
        self.dec_a = decoder(...)
        
        self.enc_b = encoder(...)
        self.dec_b = decoder(...)
    
    def forward(self, x1, x2):
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        x1 = self.dec_a(torch.cat((x1, x2), dim=-1)
        return x1, x2

keras使用函数式 API 相对容易做到这一点。但是,我在pytorch.

  1. 训练后如何丢弃dec_a(自动编码器分支的解码器部分)?
  2. 在联合训练期间,两个分支loss的总和(可选加权) ?loss

标签: machine-learningpytorchautoencoder

解决方案


您还可以为模型定义单独的模式以进行训练和推理:

class SNet(nn.Module):
  def __init__(self):
    super(SNet, self).__init__()
    
    self.enc_a = encoder(...)
    self.dec_a = decoder(...)
    
    self.enc_b = encoder(...)
    self.dec_b = decoder(...)
    
    self.training = True

  def forward(self, x1, x2):
    if self.training:
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        x1 = self.dec_a(torch.cat((x1, x2), dim=-1)
        return x1, x2
    else:
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        return x2

这些块是示例,可能无法完全按照您的意愿进行操作,因为我认为您在块图中定义训练和推理操作的方式与您的代码之间存在一些歧义,但无论如何您都知道如何只能在训练模式下使用某些模块。然后你可以相应地设置这个变量。


推荐阅读