machine-learning - 训练pytorch模型后如何丢弃分支
问题描述
我正在尝试实现一个 FCN,pytorch
其整体结构如下:
到目前为止的代码如下所示:
class SNet(nn.Module):
def __init__(self):
super(SNet, self).__init__()
self.enc_a = encoder(...)
self.dec_a = decoder(...)
self.enc_b = encoder(...)
self.dec_b = decoder(...)
def forward(self, x1, x2):
x1 = self.enc_a(x1)
x2 = self.enc_b(x2)
x2 = self.dec_b(x2)
x1 = self.dec_a(torch.cat((x1, x2), dim=-1)
return x1, x2
keras
使用函数式 API 相对容易做到这一点。但是,我在pytorch
.
- 训练后如何丢弃
dec_a
(自动编码器分支的解码器部分)? - 在联合训练期间,两个分支
loss
的总和(可选加权) ?loss
解决方案
您还可以为模型定义单独的模式以进行训练和推理:
class SNet(nn.Module):
def __init__(self):
super(SNet, self).__init__()
self.enc_a = encoder(...)
self.dec_a = decoder(...)
self.enc_b = encoder(...)
self.dec_b = decoder(...)
self.training = True
def forward(self, x1, x2):
if self.training:
x1 = self.enc_a(x1)
x2 = self.enc_b(x2)
x2 = self.dec_b(x2)
x1 = self.dec_a(torch.cat((x1, x2), dim=-1)
return x1, x2
else:
x1 = self.enc_a(x1)
x2 = self.enc_b(x2)
x2 = self.dec_b(x2)
return x2
这些块是示例,可能无法完全按照您的意愿进行操作,因为我认为您在块图中定义训练和推理操作的方式与您的代码之间存在一些歧义,但无论如何您都知道如何只能在训练模式下使用某些模块。然后你可以相应地设置这个变量。
推荐阅读
- javascript - Promise.all 找不到父 Promise 的 Resolve
- artifactory - 部署到 JFrog Artifactory 的工件的唯一链接
- javascript - 这个函数可以用正则表达式重写吗?
- python - 在同一项目中的 Django 和 django-rest-framework 之间共享经过身份验证的用户
- c# - C# 菜单条导致单击按钮时表单崩溃
- python-3.x - python字典获取键
- javascript - javascript字母字符查找器无限循环故障
- html - CSS 导航栏未向右浮动
- scrapy - Scrapy+Splash (osx) 的 GUI 和用户交互
- c# - 如何从 c# 中的 lambda 表达式列表构造 IEnumerable?