python - 如何在 Keras 中连接两个 LSTM 模型
问题描述
我想用 Keras 创建一个有两个 LSTM 层的模型。但是,以下代码会生成错误:
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Activation
from keras.callbacks import ModelCheckpoint
from keras.utils import to_categorical
model = Sequential()
model.add(LSTM(5, activation="softmax"))
model.add(LSTM(5, activation="softmax"))
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['categorical_accuracy'])
# These values are to be predicted.
directions = [-2, -1, 0, 1, 2]
# Sample data. We have three time steps, one
# feature per timestep, and one resulting value.
data = [[[[1], [2], [3]], -1],
[[[3], [2], [1]], 2],
[[[4], [5], [7]], 1],
[[[1], [-1], [10]], -2]]
X = []
y_ = []
# Now we take 10000 samples from the data above.
for i in np.random.choice(len(data), 10000):
X.append(data[i][0])
y_.append(data[i][1])
X = np.array(X)
y_ = np.array(y_)
y = to_categorical(y_ + 2, num_classes=5)
model.fit(X, y,
epochs=3,
validation_data=(X, y))
print(model.summary())
loss, acc = model.evaluate(X, y)
print("Loss: {:.2f}".format(loss))
print("Accuracy: {:.2f}%".format(acc*100))
我收到以下错误:
ValueError: Input 0 is incompatible with layer lstm_10: expected ndim=3, found ndim=2
完整的错误回溯:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-35-58fa9218c3f3> in <module>
31 model.fit(X, y,
32 epochs=3,
---> 33 validation_data=(X, y))
34 print(model.summary())
35
C:\Anaconda3\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
950 sample_weight=sample_weight,
951 class_weight=class_weight,
--> 952 batch_size=batch_size)
953 # Prepare validation data.
954 do_validation = False
C:\Anaconda3\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
675 # to match the value shapes.
676 if not self.inputs:
--> 677 self._set_inputs(x)
678
679 if y is not None:
C:\Anaconda3\lib\site-packages\keras\engine\training.py in _set_inputs(self, inputs, outputs, training)
587 assert len(inputs) == 1
588 inputs = inputs[0]
--> 589 self.build(input_shape=(None,) + inputs.shape[1:])
590 return
591
C:\Anaconda3\lib\site-packages\keras\engine\sequential.py in build(self, input_shape)
219 self.inputs = [x]
220 for layer in self._layers:
--> 221 x = layer(x)
222 self.outputs = [x]
223 self._build_input_shape = input_shape
C:\Anaconda3\lib\site-packages\keras\layers\recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
530
531 if initial_state is None and constants is None:
--> 532 return super(RNN, self).__call__(inputs, **kwargs)
533
534 # If any of `initial_state` or `constants` are specified and are Keras
C:\Anaconda3\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
412 # Raise exceptions in case the input is not compatible
413 # with the input_spec specified in the layer constructor.
--> 414 self.assert_input_compatibility(inputs)
415
416 # Collect input shapes to build layer.
C:\Anaconda3\lib\site-packages\keras\engine\base_layer.py in assert_input_compatibility(self, inputs)
309 self.name + ': expected ndim=' +
310 str(spec.ndim) + ', found ndim=' +
--> 311 str(K.ndim(x)))
312 if spec.max_ndim is not None:
313 ndim = K.ndim(x)
ValueError: Input 0 is incompatible with layer lstm_10: expected ndim=3, found ndim=2
似乎第一个 LSTM 层的输出维度(假设 dim=2)与第二个 LSTM 层所需的输入维度(对于批处理、时间步长、特征而言,dim=3)不匹配。
让我烦恼的是,以我的方式将 LSTM 层添加在一起似乎在这里工作,例如:https ://adventuresinmachinelearning.com/keras-lstm-tutorial/
当我删除第二个 LSTM 层时,该模型有效。
解决方案
默认情况下,LSTM 仅在序列的最后一个元素之后返回它的最终输出。如果要将两个链接在一起,则需要在序列的每个元素之后将输出从第一个 LSTM 传递到第二个。例如
model = Sequential()
model.add(LSTM(5, return_sequences=True))
model.add(LSTM(5, activation="softmax"))
有关 return_sequence 如何工作的详细信息,请参阅文档https://keras.io/layers/recurrent/
推荐阅读
- c# - WPF 组合框中的选定项,与字典绑定,显示完整对象,而不仅仅是值
- regex - 正则表达式仅提取 .com 域?
- python - 遍历决策树并捕获每个节点
- python - 动态 URL 和 Jinja 模板
- vim - 删除当前行和前面的 n-1 行
- openmdao - openMDAO:优化在 1 次迭代后成功终止,而不是在最佳点
- ruby - What's the difference between binding.pry and Pry.start?
- pandas - 你如何只去掉熊猫中一列的整数?
- c# - 如何在 Maya 中实现 MObject 选择器
- java - 如何在java中输入两个名称之间有空格