pytorch - 可以在 Pytorch nn.Sequential() 中添加条件
问题描述
有没有办法在nn.Sequential()
. 类似于下面的代码。
import torch
class Building_Blocks(torch.nn.Module):
def conv_block (self, in_features, out_features, kernal_size, upsample=False):
block = torch.nn.Sequential(
torch.nn.Conv2d(in_features, out_features, kernal_size),
torch.nn.ReLU(inplace = True),
torch.nn.Conv2d(out_features, out_features, kernal_size),
torch.nn.ReLU(inplace = True),
if(upsample):
torch.nn.ConvTranspose2d(out_features, out_features, kernal_size)
)
return block
def __init__(self):
super(Building_Blocks, self).__init__()
self.contracting_layer1 = self.conv_block(3, 64, 3, upsample=True)
def forward(self, x):
x=self.contracting_layer1(x)
return x
解决方案
不,但是在您的情况下,很容易if
摆脱nn.Sequential
:
class Building_Blocks(torch.nn.Module):
def conv_block(self, in_features, out_features, kernal_size, upsample=False):
layers = [
torch.nn.Conv2d(in_features, out_features, kernal_size),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_features, out_features, kernal_size),
torch.nn.ReLU(inplace=True),
]
if upsample:
layers.append(
torch.nn.ConvTranspose2d(out_features, out_features, kernal_size)
)
block = torch.nn.Sequential(*layers)
return block
def __init__(self):
super(Building_Blocks, self).__init__()
self.contracting_layer1 = self.conv_block(3, 64, 3, upsample=True)
def forward(self, x):
x = self.contracting_layer1(x)
return x
您始终可以根据需要构建list
包含层,然后将其解压缩torch.nn.Sequential
。
推荐阅读
- c# - 如何在另一个场景中引用对象的组件?
- python - AWS lambda UnicodeDecodeError
- swift - 使用 typealias 从回调转移到委托
- kubernetes - Kubernetes 服务将外部 ip 设置为待处理
- javascript - 有没有办法计算 discord.js 中提及的字符
- javascript - Html 联系表单按钮处于非活动状态
- xamarin - 如何在 Xamarin 表单中针对不同环境在 android 和 iOS 项目上配置 APP CENTER 应用程序密钥?
- python - 如何在 Pyspark 中使用 groupby 删除列
- javascript - 当我播放我的 3 首曲目时,它们相互重叠,我如何在播放 track2 时暂停 track1?使用 javascript
- python - 如何在 tf.data.Dataset.map 中使用 sklearn.preprocessing?