首页 > 解决方案 > 关于 Keras LSTM 的输出

问题描述

我使用 Keras 构建了一个 LSTM 架构。我的目标是将长度为 29 的浮点时间序列输入序列映射到长度为 29 的浮点输出序列。我正在尝试实施“多对多”的方法。我按照这篇文章实现了这样的模型。

我首先将每个数据点重塑np.array为形状为 `(1, 29, 1) 的形状。我有多个数据点,并分别在每个数据点上训练模型。以下代码是我构建模型的方式:

def build_model():
    # define model
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.LSTM(29, return_sequences=True, input_shape=(29, 1)))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.3))

    model.compile(optimizer='sgd', loss='mse', metrics = ['mae'])

    #cast data
    for point in train_dict:
        train_data = train_dict[point]

        train_dataset = tf.data.Dataset.from_tensor_slices((
            tf.cast(train_data[0], features_type),
            tf.cast(train_data[1], target_type))
        ).repeat() #cast into X, Y

        # fit model


        model.fit(train_dataset, epochs=100,steps_per_epoch = 1,verbose=0)


        print(model.summary())   
    return model 

我很困惑,因为当我调用model.predict(test_point, steps = 1, verbose = 1)模型时返回 29 长度 29 序列!根据我对链接帖子的理解,我不明白为什么会发生这种情况。当我尝试return_state=True而不是return_sequences=Truethen 我的代码会引发此错误:ValueError: All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.

我该如何解决这个问题?

标签: pythontensorflowkerasneural-networklstm

解决方案


您的模型几乎没有缺陷。

  1. 模型的最后一层是 LSTM。假设您正在进行分类/回归。这之后应该是一个密集层(SoftMax/sigmoid - 分类,线性 - 回归)。但由于这是一个时间序列问题,因此应将密集层包装在 TimeDistributed 包装器中。

  2. 在 LSTM 之上应用 LeakyReLU 很奇怪。

我已经修复了上述问题的代码。看看是否有帮助。

from tensorflow.keras.layers import Embedding, Input, Bidirectional, LSTM, Dense, Concatenate, LeakyReLU, TimeDistributed
from tensorflow.keras.initializers import Constant
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential
def build_model():
    # define model
    model = Sequential()
    model.add(LSTM(29, return_sequences=True, input_shape=(29, 1)))
    model.add(TimeDistributed(Dense(1)))
    model.compile(optimizer='sgd', loss='mse', metrics = ['mae'])


    print(model.summary())   
    return model 

model = build_model()

推荐阅读