pytorch - 减少torch.nn.lstm中每层的节点数
问题描述
有没有一种简单的方法可以将每层中的节点数量减少一个因子?我在文档页面上没有看到这个选项,也许我可以使用类似的功能而不是手动定义每一层?
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=0.2,
) # lstm
解决方案
我不知道,但从头开始编写它很简单:
def _constant_scale(initial: int, factor: int) -> int:
return initial//factor
class StackedLSTM(Module):
def __init__(self, input_size: int, hidden_sizes: list[int], *args, **kwargs):
super(StackedLSTM, self).__init__()
self.layers = ModuleList([LSTM(input_size=xs, hidden_size=hs, *args, **kwargs) for xs, hs in zip([input_size] + hidden_sizes, hidden_sizes)])
def forward(self, x: Tensor, hc: Optional[tuple[Tensor, Tensor]] = None) -> Tensor:
for layer in self.layers:
x, _ = layer(x, hc)
hc = None
return x
hidden_sizes = [_constant_scale(300, 2**i) for i in range(3)]
sltm = StackedLSTM(100, hidden_sizes)
x = torch.rand(10, 32, 100)
h = torch.rand(1, 32, 300)
c = torch.rand(1, 32, 300)
out = sltm(x, (h, c))
print(out.shape)
# torch.Size([10, 32, 75])
推荐阅读
- swift - Swift Firestore 搜索用户
- javascript - 如何通过子字符串数组过滤大量字符串?
- c# - Xamarin Forms 简单的按钮动画
- java - 如何将流式 json 数据作为键值对发送到 kafka 消费者
- python - 高效地查找和替换列表中的字符串 Python
- excel - 使用 VBA 在远程桌面上打开文件
- c# - Unity framework in .net core
- excel - 图像下载失败时挂起的 Excel VBA 代码的超时或错误处理
- mediawiki - 是否可以在 page1 上创建变量并在 page2 上使用该变量?(不使用扩展)
- java - 如何在 if-else 语句中使用用户输入,如果用户输入不是所要求的,我如何循环代码?