首页 > 解决方案 > 我可以在 nn.Sequential 中解压一个 nn.ModuleList 吗?

问题描述

我正在使用 参数化简单 ANN 的隐藏层数nn.ModuleList。我想知道是否将此列表nn.Sequential按如下方式传递到模块中会导致执行图产生任何不利影响。

nn.Sequential不是必需的,但是对我来说,在构造函数中明确显示整个架构似乎更干净。

class ANN(nn.Module):

    def __init__(
        self,
        in_feats=3,
        in_hidden=5,
        n_hidden_layers=3,

    ):
        super(ANN, self).__init__()

        # ====== dynamically register hidden layers ======
        self.hidden_linears = nn.ModuleList()
        for i in range(n_hidden_layers):
            self.hidden_linears.append(nn.Linear(in_hidden, in_hidden))
            self.hidden_linears.append(nn.ReLU())
        
        # ====== sequence of layers ======
        self.layers = nn.Sequential(
            nn.Linear(in_feats, in_hidden),
            nn.ReLU(),
            *self.hidden_linears,
            nn.Linear(in_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, X):
        return self.layers(X)

也愿意接受有关将其组合在一起的更简洁方法的建议。

标签: pythonneural-networkpytorch

解决方案


推荐阅读