首页 > 解决方案 > 为 LSTM 模型调用预测函数时出现有关输入形状的错误

问题描述

我已经安装了一个 lstm 模型。每个 x 和 y 变量有 100 个观测值。我用 80 个值来训练模型,用 20 个值来评估模型。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, LSTM, Dropout

input_features=1
time_stamp=100
out_features=1
x_input=np.random.random((time_stamp,input_features))
y_input=np.random.random((time_stamp,out_features))

train_x=x_input[0:80,:]
test_x=x_input[80:,:]

train_y=y_input[0:80,:]
test_y=y_input[80:,:]

然后,在将数据输入 LSTM 函数之前,我根据需要对数据进行了重新整形。(例如:用于训练 x:(samples, timesteps, features)=(1,80,1))

dataXtrain = train_x.reshape(1,80, 1)
dataXtest = test_x.reshape(1,20,1)
dataYtrain = train_y.reshape(1,80,1)
dataYtest = test_y.reshape(1,20,1)
dataXtrain.shape
(1, 80, 1)

然后我能够使用以下代码行成功拟合模型:

model = Sequential()
model.add(LSTM(20,activation = 'relu', return_sequences = True,input_shape=(dataXtrain.shape[1], 
dataXtrain.shape[2])))
model.add(Dense(1))
model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit(dataXtrain, dataYtrain, epochs=100, batch_size=10, verbose=1)

但是当我预测测试数据的模型时,我得到了这个错误。

y_pred = model.predict(dataXtest)
Error when checking input: expected lstm_input to have shape (80, 1) but got array with shape (20, 1)

谁能帮我弄清楚这里有什么问题?

谢谢

标签: pythonkerasdeep-learninglstm

解决方案


似乎问题出在数据准备上。我认为你应该划分你的样本(而不是时间步长)来训练和测试数据,并且训练和测试样本的形状应该是一样的(None, time-steps, features)

由于您只有一个包含 100 个观测值(时间步长)的样本,因此您可以将数据划分为包含小尺寸时间步长序列的样本。例如:

n_samples = 20
input_features = 1
time_stamp = 100
out_features = 1
x_input = np.random.random((1, time_stamp, input_features))
y_input = np.random.random((1, time_stamp, out_features))

new_time_stamp = time_stamp//n_samples
x_input = x_input.reshape(n_samples, new_time_stamp, input_features)
y_input = y_input.reshape(n_samples, new_time_stamp, out_features)

dataXtrain = x_input[:16,...]
dataXtest = x_input[16:,...]

dataYtrain = y_input[:16,...]
dataYtest = y_input[16:,...]

或者,您可以收集更多数据样本,每个样本包含 100 个时间步长(这取决于您的应用程序和现有数据)。

你也可以看看这个这个在 Keras 中使用 LSTM 进行全面解释的内容。


推荐阅读