python - Pytorch:我们可以直接在 forward() 函数中使用 nn.Module 层吗?
问题描述
通常,在构造函数中,我们声明了我们要使用的所有层。在 forward 函数中,我们定义了模型将如何运行,从输入到输出。
我的问题是,如果直接在函数中调用那些预定义/内置的nn.Modules会怎样?forward()
Pytorch 的这种Keras函数 API样式是否合法?如果不是,为什么?
更新:以这种方式构建的TestModel确实运行成功,没有警报。但与传统方式相比,训练损失会缓慢下降。
import torch.nn as nn
from cnn import CNN
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.num_embeddings = 2020
self.embedding_dim = 51
def forward(self, input):
x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
# CNN is a customized class and nn.Module subclassed
# we will ignore the arguments for its instantiation
x = CNN(...)(x)
x = nn.ReLu()(x)
x = nn.Dropout(p=0.2)(x)
return output = x
解决方案
您需要考虑可训练参数的范围。
如果你在forward
模型的函数中定义了一个卷积层,那么这个“层”的范围和它的可训练参数是函数本地的,并且在每次调用该forward
方法后都会被丢弃。forward
您无法更新和训练每次通过后不断丢弃的权重。
但是,当 conv 层是您的成员时,model
它的范围超出了forward
方法,并且只要model
对象存在,可训练参数就会持续存在。通过这种方式,您可以更新和训练模型及其权重。
推荐阅读
- c - Socket数据发送正确但接收改变
- variables - 如何在将值列表附加到变量时启动make?
- php - 为什么我会收到此编译器错误。语法对我来说是正确的
- c# - Azure.Storage.Blobs 文件夹结构
- python - 从 Twitter 推文中删除 unicode 编码的表情符号
- json - 在JQ中的公共字段上合并数组元素
- javascript - 如何根据 Svelte 中另一家商店的价值稳健地更新一家商店?
- amazon-web-services - 在 Bitbucket 管道中使用不同的 aws 凭证
- java - 如何解决 web-app 必须在 Spring 中声明错误?
- reactjs - 错误:操作必须是普通对象。使用自定义中间件进行异步操作。如何解决?