Keras - how to reshape LSTM input values

Problem description

I am using an LSTM with Keras. My training and testing work fine, but when I try to feed in a different input I get the error cannot reshape array of size 20 into shape (1,20,30).

Here is my model.fit() code

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

PositiveOrNegativeLabel = np.array([[1]])
PositiveOrNegativeLabel = PositiveOrNegativeLabel.reshape(1, -1)
PositiveOrNegativeLabel.shape
inputBatch = inputBatch.reshape(1, 24, 30)
testBatch = testBatch.reshape(1, 24, 30)

model = Sequential()
model.add(LSTM(100, input_shape=(24, 30)))
model.add(Dense(1, activation="relu"))
model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit(inputBatch, PositiveOrNegativeLabel, batch_size=24, epochs=9, verbose=1)

My input is the firstSentence array, which looks like this

[  174 11501   420  4242 12111     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0]

firstSentence has shape (20,)

The input I used to fit the model looks like this

[[    0. 12184.   420.  4636.     0.  8840.     0.     0. 10499. 11508.
   7511.     0.  5468.  2879.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  6689.  2818. 12003.  6480.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.     0.  3045. 11087.  2710.     0.   494.  1087.   420.  4995.
  11516.  3637.  5842.     0.  9963.  7015. 11090.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  1287.   420.  4070. 11087.  7410. 12186.  2387. 12111.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  3395.  1087. 11904.  7232.  8840. 10115.  4494. 11516.  7441.
   8535. 12106.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.   494.     0.     0.  6541.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  8744. 11105.  1570.  5842.   174. 11266.  2929. 10438.  2879.
      0. 10936.  6330.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0. 11956.  5222.     0.     0. 12106.  6481.     0.  7093. 13756.
  12152.     0.     0.     0.     0. 10173.     0.  5173. 13756.  9371.
      0.  9956.     0.     0.  9716.     0.     0.     0.     0.     0.]
 [    0.  3395.  1087. 11904.  7232.  8840. 10115.  4494. 11516.  7441.
   8535. 12106.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.   420.  5842.  3058. 11875.  2879.  1087. 11105.  4995.  8840.
      0. 11100. 11875.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  5419.   420.  2250.  1299.  2151. 12111.  6454.     0. 11501.
   8094.  5842.   942.  7503.  7410.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.   420.  5842.  3058. 11875.  2879.  1087. 11105.  4995.  8840.
      0. 11100. 11875.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  3395.  1087. 11904.  7232.  8840. 10115.  4494. 11516.  7441.
   8535. 12106.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  1287.   420.  4070. 11087.  7410. 12186.  2387. 12111.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  3395.  1087. 11904.  7232.  8840. 10115.  4494. 11516.  7441.
   8535. 12106.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0. 11501.  1592. 10603. 11102.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.   174.  5842.  2387. 10453. 11090.     0.  7531. 11956.   450.
    420. 11516.  6693.  2624.  9963. 11992.  9322. 11090. 12106.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  7544.     0.  1709.   420. 10936.  5222.  5842. 10407.  6937.
  11329.  2937.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  1520.  1295.     0.  8396.  9322. 12715.     0.  5172.  7232.
  11266.     0. 11266.  2757.  4416. 12020. 12111.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  7544.     0.  1709.   420. 10936.  5222.  5842. 10407.  6937.
  11329.  2937.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.     0.  9191.  5952.     0.     0. 11516.  9413.  3081.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.     0.     0.     0.     0.     0. 11516.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  3395.  1087. 11904.  7232.  8840. 10115.  4494. 11516.  7441.
   8535. 12106.     0.     0.     0.     0.     0.     0.     0.     0.
      0.     0.     0.     0.     0.     0.     0.     0.     0.     0.]
 [    0.  9371. 10412.  2356.  5412. 11502.     0.  1087.   228.     0.
   2937. 11480. 10412.  5412.   420.  9435.  2937.   228.  1057.  9435.
  12111.     0.     0.     0.     0.     0.     0.     0.     0.     0.]]

The ValueError I get is this

cannot reshape array of size 20 into shape (1,20,30)
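
The message itself is just a size mismatch: a shape of (1, 20, 30) needs 1 * 20 * 30 = 600 values, while firstSentence only holds 20. A minimal NumPy sketch that reproduces the same error:

import numpy as np

arr = np.arange(20)       # same size as firstSentence, shape (20,)
arr.reshape(1, 20, 30)    # ValueError: cannot reshape array of size 20 into shape (1,20,30)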

I am running this code to make the prediction

predict=model.predict(firstSentence, batch_size=24, verbose=1, steps=None)
# make a prediction
ynew = model.predict_classes(firstSentence)
# show the inputs and predicted outputs
for i in range(len(predict)):
    print("X=%s, Predicted=%s" % (predict[i], ynew[i]))

Tags: python, tensorflow, keras

Solution


I think you are confusing batch_size with the number of timesteps. The input_shape for an LSTM layer in Keras should be (num_timesteps, num_features); the batch size is inferred from the batch_size used in fit. From what I can tell, your training data has 24 examples, each with 30 timesteps and 1 feature at each timestep.

Also, your code has only one label, but as far as I can tell you have 24 examples, so you need 24 labels.

import numpy as np 
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Need same number of labels as examples
PositiveOrNegativeLabel = np.ones(shape = (24, 1))
# 24 examples, each with 30 timesteps, 1 feature at each timestep
# ("result" is assumed to be the (24, 30) training array shown in the question)
inputBatch = result.reshape(24, 30, 1)

model=Sequential()

# input_shape is (num_timesteps, num_features)
model.add(LSTM(100, input_shape=(30, 1)))
model.add(Dense(1, activation="relu"))

model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit(inputBatch, PositiveOrNegativeLabel,
          batch_size=24, epochs=9, verbose=1)
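
As a quick sanity check (a sketch, assuming the arrays above have already been built as shown), you can print the shapes before training; note that the batch dimension never appears in input_shape, because it is supplied by batch_size at fit time:

print(inputBatch.shape)               # (24, 30, 1)
print(PositiveOrNegativeLabel.shape)  # (24, 1)
model.summary()                       # shapes are reported with a leading None: the batch size stays flexible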

Your test sentence is a single example with 20 timesteps and 1 feature at each timestep, so you need to pad it with zeros up to 30 timesteps.

from keras.preprocessing.sequence import pad_sequences

# Need 30 timesteps
firstSentence = pad_sequences([firstSentence], maxlen = 30)[0]
firstSentence = firstSentence.reshape((1, 30, 1))

predict=model.predict(firstSentence, batch_size=1, verbose=1, steps=None)

# make a prediction
ynew = model.predict_classes(firstSentence)
# show the inputs and predicted outputs
for i in range(len(predict)):
    print("X=%s, Predicted=%s" % (predict[i], ynew[i]))

This code works, but I may have misinterpreted your setup. If your training label suggests that you really have only 1 training example, then the input batch should have shape (1, 24, 30), and your test sequence would also need 24 timesteps with 30 features at each timestep.
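
If that alternative reading is correct, a minimal sketch of the shapes (hypothetical, and only valid if you truly have a single training example) would look like this:

# Alternative reading: 1 example, 24 timesteps, 30 features per timestep
PositiveOrNegativeLabel = np.array([[1]])    # shape (1, 1): one label for the one example
inputBatch = inputBatch.reshape(1, 24, 30)   # shape (1, 24, 30)

model = Sequential()
model.add(LSTM(100, input_shape=(24, 30)))   # input_shape = (num_timesteps, num_features)
model.add(Dense(1, activation="relu"))
model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit(inputBatch, PositiveOrNegativeLabel, batch_size=1, epochs=9, verbose=1)

# The test sequence would then also need shape (1, 24, 30)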

For more information on shaping LSTM inputs in Keras, see this post.

