python - pytorch 前向检查的输入尺寸错误
问题描述
我正在创建一个RNN,pytorch
它看起来像这样:
class MyRNN(nn.Module):
def __init__(self, batch_size, n_inputs, n_neurons, n_outputs):
super(MyRNN, self).__init__()
self.n_neurons = n_neurons
self.batch_size = batch_size
self.n_inputs = n_inputs
self.n_outputs = n_outputs
self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons)
self.FC = nn.Linear(self.n_neurons, self.n_outputs)
def init_hidden(self, ):
# (num_layers, batch_size, n_neurons)
return torch.zeros(1, self.batch_size, self.n_neurons)
def forward(self, X):
self.batch_size = X.size(0)
self.hidden = self.init_hidden()
lstm_out, self.hidden = self.basic_rnn(X, self.hidden)
out = self.FC(self.hidden)
return out.view(-1, self.n_outputs)
我的输入x
如下所示:
tensor([[-1.0173e-04, -1.5003e-04, -1.0218e-04, -7.4541e-05, -2.2869e-05,
-7.7171e-02, -4.4630e-03, -5.0750e-05, -1.7911e-04, -2.8082e-04,
-9.2992e-06, -1.5608e-05, -3.5471e-05, -4.9127e-05, -3.2883e-01],
[-1.1193e-04, -1.6928e-04, -1.0218e-04, -7.4541e-05, -2.2869e-05,
-7.7171e-02, -4.4630e-03, -5.0750e-05, -1.7911e-04, -2.8082e-04,
-9.2992e-06, -1.5608e-05, -3.5471e-05, -4.9127e-05, -3.2883e-01],
...
[-6.9490e-05, -8.9197e-05, -1.0218e-04, -7.4541e-05, -2.2869e-05,
-7.7171e-02, -4.4630e-03, -5.0750e-05, -1.7911e-04, -2.8082e-04,
-9.2992e-06, -1.5608e-05, -3.5471e-05, -4.9127e-05, -3.2883e-01]],
dtype=torch.float64)
是一组大小为 15 的 64 个向量。
尝试通过以下方式测试此模型时:
BATCH_SIZE = 64
N_INPUTS = 15
N_NEURONS = 150
N_OUTPUTS = 1
model = MyRNN(BATCH_SIZE, N_INPUTS, N_NEURONS, N_OUTPUTS)
model(x)
我收到以下错误:
File "/home/tt/anaconda3/envs/venv/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 126, in check_forward_args
expected_input_dim, input.dim()))
RuntimeError: input must have 3 dimensions, got 2
我该如何解决?
解决方案
您缺少 RNN 层所需的维度之一。
根据文档,您的输入大小需要具有一定的形状(序列长度、批次、输入大小)。
所以 - 对于上面的示例,您缺少其中之一。根据您的变量名称,您似乎正在尝试传递 64 个示例,每个示例包含 15 个输入……如果这是真的,那么您缺少序列长度。
对于 RNN,序列长度是您希望层重复出现的次数。例如,在 NLP 中,您的序列长度可能等于一个句子中的单词数,而批量大小将是您传递的句子数,输入大小将是每个单词的向量大小。
如果你只是想使用 64 个大小为 15 的样本,你可能不需要 RNN。
推荐阅读
- elasticsearch - 如何激活和配置 ElasticSearch Kafka Connect 接收器?
- html - Angular 2 Form 帮助...后端功能
- javascript - 如何在创建实际的 SVG 元素(在 Angular 中)之前测量 SVG 中字符串的宽度
- r - R ggplot:如何使用边缘箱形图创建散点图
- javascript - 按属性内对象中的自定义顺序排序
- java - 是否可以在单元测试中绕过某些异常?
- java - 添加 JLabel 和 JButton 后看不到 JTextField
- java - 在 Spring Boot / FreeMarker 中加载图像时出现问题
- javascript - Javascript不调用函数
- optimization - 优化 applescript 以减少能源影响