python - 为 LSTM 模型调用预测函数时出现有关输入形状的错误
问题描述
我已经安装了一个 lstm 模型。每个 x 和 y 变量有 100 个观测值。我用 80 个值来训练模型,用 20 个值来评估模型。
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, LSTM, Dropout
input_features=1
time_stamp=100
out_features=1
x_input=np.random.random((time_stamp,input_features))
y_input=np.random.random((time_stamp,out_features))
train_x=x_input[0:80,:]
test_x=x_input[80:,:]
train_y=y_input[0:80,:]
test_y=y_input[80:,:]
然后,在将数据输入 LSTM 函数之前,我根据需要对数据进行了重新整形。(例如:用于训练 x:(samples, timesteps, features)=(1,80,1))
dataXtrain = train_x.reshape(1,80, 1)
dataXtest = test_x.reshape(1,20,1)
dataYtrain = train_y.reshape(1,80,1)
dataYtest = test_y.reshape(1,20,1)
dataXtrain.shape
(1, 80, 1)
然后我能够使用以下代码行成功拟合模型:
model = Sequential()
model.add(LSTM(20,activation = 'relu', return_sequences = True,input_shape=(dataXtrain.shape[1],
dataXtrain.shape[2])))
model.add(Dense(1))
model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit(dataXtrain, dataYtrain, epochs=100, batch_size=10, verbose=1)
但是当我预测测试数据的模型时,我得到了这个错误。
y_pred = model.predict(dataXtest)
Error when checking input: expected lstm_input to have shape (80, 1) but got array with shape (20, 1)
谁能帮我弄清楚这里有什么问题?
谢谢
解决方案
似乎问题出在数据准备上。我认为你应该划分你的样本(而不是时间步长)来训练和测试数据,并且训练和测试样本的形状应该是一样的(None, time-steps, features)
。
由于您只有一个包含 100 个观测值(时间步长)的样本,因此您可以将数据划分为包含小尺寸时间步长序列的样本。例如:
n_samples = 20
input_features = 1
time_stamp = 100
out_features = 1
x_input = np.random.random((1, time_stamp, input_features))
y_input = np.random.random((1, time_stamp, out_features))
new_time_stamp = time_stamp//n_samples
x_input = x_input.reshape(n_samples, new_time_stamp, input_features)
y_input = y_input.reshape(n_samples, new_time_stamp, out_features)
dataXtrain = x_input[:16,...]
dataXtest = x_input[16:,...]
dataYtrain = y_input[:16,...]
dataYtest = y_input[16:,...]
或者,您可以收集更多数据样本,每个样本包含 100 个时间步长(这取决于您的应用程序和现有数据)。
推荐阅读
- html - ngStyle 上角 5 的动画
- passport.js - Composer REST 服务器自定义身份验证
- android - android.view.InflateException: Binary XML file line #0: Error inflating class
- javascript - Spring MVC 使用 ajax 发送和接收参数
- neural-network - 卷积神经网络的归一化方法
- java - 使用 JOOQ DSL API 在 VALUES() 上加入表
- latex - Latex:在主文件中定义时在子文件中使用变量
- mysql-workbench - 涉及具有不同计数的特定计数的 SQL 查询
- javascript - 在 react-boilerplate 项目中从“redux-form”导入“字段”后出现错误
- android - Firebase 使用情况分析