pytorch - pytorch 中的 torch.nn.gru 函数的输入是什么?
问题描述
我正在使用 gru 函数来实现 RNN。这个 RNN (GRU) 在一些 CNN 层之后使用。有人可以告诉我这里 GRU 函数的输入是什么吗?特别是,隐藏的大小是固定的吗?
self.gru = torch.nn.GRU(
input_size=input_size,
hidden_size=128,
num_layers=1,
batch_first=True,
bidirectional=True)
根据我的理解,输入大小将是特征的数量,而 GRU 的隐藏大小总是固定为 128?有人可以纠正我。或提供他们的反馈
解决方案
首先,GRU
它不是一个函数而是一个类,你正在调用它的构造函数。您在GRU
这里创建了一个类的实例,它是一个层(或Module
在 pytorch 中)。
input_size
必须与前out_channels
一个 CNN 层相匹配。
您看到的所有参数都不是固定的。只要把另一个值放在那里,它就会是别的东西,即用你喜欢的任何东西替换 128。
即使它被称为hidden_size
,对于 GRU,此参数也决定了输出特征。换句话说,如果您在 GRU 之后还有另一层,则该层input_size
(或in_features
或in_channels
或其他任何名称)必须与 GRU 匹配hidden_size
。
另外,请查看文档。这会准确地告诉您传递给构造函数的参数的用途。此外,它会告诉您在实际使用层(通过self.gru(...)
)后预期的输入是什么,以及该调用的输出是什么。
推荐阅读
- python - python zipfile无法在request.raw中打开类似文件的对象流
- c++ - 如何选择一个不包括先前选择的随机数?
- javascript - 包含页眉和页脚不适用于 ejs
- mysql - 使用 concat 更改格式数据并替换 mysql
- javascript - 无错误或运行时错误的多个 DOM 遍历
- html - 普通文本字段可行,但一旦我添加表单控件,它就会停止工作
- python-multithreading - 多次启动线程与循环flask-socketIO
- mysql - 不允许在(子)分区函数中使用 MySQL 常量、随机或时区相关的表达式
- express - 如何更改 Electron 中的响应发送文件路径
- flutter - 有没有办法在 Flutter 中使图像文件的白色背景完全透明?