首页 > 解决方案 > Pytorch:我们可以直接在 forward() 函数中使用 nn.Module 层吗?

问题描述

通常,在构造函数中,我们声明了我们要使用的所有层。在 forward 函数中,我们定义了模型将如何运行,从输入到输出。

我的问题是,如果直接在函数中调用那些预定义/内置的nn.Modules会怎样?forward()Pytorch 的这种Keras函数 API样式是否合法?如果不是,为什么?

更新:以这种方式构建的TestModel确实运行成功,没有警报。但与传统方式相比,训练损失会缓慢下降。

import torch.nn as nn
from cnn import CNN

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_embeddings = 2020
        self.embedding_dim = 51

    def forward(self, input):
        x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
        # CNN is a customized class and nn.Module subclassed
        # we will ignore the arguments for its instantiation 
        x = CNN(...)(x)
        x = nn.ReLu()(x)
        x = nn.Dropout(p=0.2)(x)
        return output = x

标签: pythonconstructorpytorchconv-neural-networkforward

解决方案


您需要考虑可训练参数的范围

如果你在forward模型的函数中定义了一个卷积层,那么这个“层”的范围和它的可训练参数是函数本地的,并且在每次调用该forward方法后都会被丢弃。forward您无法更新和训练每次通过后不断丢弃的权重。
但是,当 conv 层是您的成员时,model它的范围超出了forward方法,并且只要model对象存在,可训练参数就会持续存在。通过这种方式,您可以更新和训练模型及其权重。


推荐阅读