python - 在 Pytorch 中测试 LSTM 的实现
问题描述
我正在尝试在这里使用 LSTM 的 Pytorch 实现。我把它包括在这里供参考。它由两个类组成,LSTMCell 和 LSTM,其中 LSTMCell 只是一个单元,LSTM 将多个单元堆叠在一起以创建完整的 LSTM 模型
import math
import torch as th
import torch.nn as nn
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, bias=True):
super(LSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
self.reset_parameters()
def reset_parameters(self):
std = 1.0 / math.sqrt(self.hidden_size)
for w in self.parameters():
w.data.uniform_(-std, std)
def forward(self, x, hidden):
if hidden is None:
hidden = self._init_hidden(x)
h, c = hidden
h = h.view(h.size(1), -1)
c = c.view(c.size(1), -1)
x = x.view(x.size(1), -1)
# Linear mappings
preact = self.i2h(x) + self.h2h(h)
# activations
gates = preact[:, :3 * self.hidden_size].sigmoid()
g_t = preact[:, 3 * self.hidden_size:].tanh()
i_t = gates[:, :self.hidden_size]
f_t = gates[:, self.hidden_size:2 * self.hidden_size]
o_t = gates[:, -self.hidden_size:]
c_t = th.mul(c, f_t) + th.mul(i_t, g_t)
h_t = th.mul(o_t, c_t.tanh())
h_t = h_t.view(1, h_t.size(0), -1)
c_t = c_t.view(1, c_t.size(0), -1)
return h_t, (h_t, c_t)
@staticmethod
def _init_hidden(input_):
h = th.zeros_like(input_.view(1, input_.size(1), -1))
c = th.zeros_like(input_.view(1, input_.size(1), -1))
return h, c
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, bias=True):
super().__init__()
self.lstm_cell = LSTMCell(input_size, hidden_size, bias)
def forward(self, input_, hidden=None):
# input_ is of dimensionalty (1, time, input_size, ...)
outputs = []
for x in torch.unbind(input_, dim=1):
hidden = self.lstm_cell(x, hidden)
outputs.append(hidden[0].clone())
return torch.stack(outputs, dim=1)
我正在做以下简单的测试:
x = torch.randn(1, 3, 2, 4)
model = LSTM(4, 5, False)
model(x)
我收到以下错误。这里到底有什么问题?
TypeError Traceback (most recent call last)
<ipython-input-33-09e5544a61fc> in <module>
----> 1 model = LSTM(4, 5, False)
<ipython-input-30-9ad06cd4b768> in __init__(self, input_size, hidden_size, bias)
3 def __init__(self, input_size, hidden_size, bias=True):
4 super().__init__()
----> 5 self.lstm_cell = LSTMCell(input_size, hidden_size, bias)
6
7 def forward(self, input_, hidden=None):
<ipython-input-29-c91ddfb9dfae> in __init__(self, input_size, hidden_size, bias)
6
7 def __init__(self, input_size, hidden_size, bias=True):
----> 8 super(LSTM, self).__init__()
9 self.input_size = input_size
10 self.hidden_size = hidden_size
TypeError: super(type, obj): obj must be an instance or subtype of type
解决方案
第一个参数super()
应该是类本身,而不是不同的类。
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, bias=True):
super(LSTM, self).__init__()
# ^^^^ self is not an instance of LSTM but LSTMCell
它应该是:
super(LSTMCell, self).__init__()
从 Python 3 开始,您可以省略 super 的参数以获得相同的结果(就像您在LSTM
课堂上所做的那样):
super().__init__()
推荐阅读
- heroku - Heroku 和 axios
- google-maps-api-3 - 如何通过再次单击标记关闭打开的谷歌地图信息窗口
- python - 如何在 Matplotlib 中获取散点图的最小(或最大)边界
- c# - 使用 GetType().GetTypeInfo().GetDeclaredProperty 为现有属性设置值时 C# 参考错误
- python - Pandas 数据框和字典的深拷贝
- socket.io - 使用分子运行程序而不是 ServiceBroker 将分子 io 与分子网络集成的示例?
- c++ - 为什么每次执行时函数的地址都不同?
- vba - 我使用这段代码在 Access 2013 (ADO VBA) 中收到“错误 3251”检索表键
- apache-spark - 在 spark sql 中将字符串类型转换为数组类型
- data-visualization - 在 tableau 中创建日期计算字段