python - 关于 Keras LSTM 的输出
问题描述
我使用 Keras 构建了一个 LSTM 架构。我的目标是将长度为 29 的浮点时间序列输入序列映射到长度为 29 的浮点输出序列。我正在尝试实施“多对多”的方法。我按照这篇文章实现了这样的模型。
我首先将每个数据点重塑np.array
为形状为 `(1, 29, 1) 的形状。我有多个数据点,并分别在每个数据点上训练模型。以下代码是我构建模型的方式:
def build_model():
# define model
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(29, return_sequences=True, input_shape=(29, 1)))
model.add(tf.keras.layers.LeakyReLU(alpha=0.3))
model.compile(optimizer='sgd', loss='mse', metrics = ['mae'])
#cast data
for point in train_dict:
train_data = train_dict[point]
train_dataset = tf.data.Dataset.from_tensor_slices((
tf.cast(train_data[0], features_type),
tf.cast(train_data[1], target_type))
).repeat() #cast into X, Y
# fit model
model.fit(train_dataset, epochs=100,steps_per_epoch = 1,verbose=0)
print(model.summary())
return model
我很困惑,因为当我调用model.predict(test_point, steps = 1, verbose = 1)
模型时返回 29 长度 29 序列!根据我对链接帖子的理解,我不明白为什么会发生这种情况。当我尝试return_state=True
而不是return_sequences=True
then 我的代码会引发此错误:ValueError: All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.
我该如何解决这个问题?
解决方案
您的模型几乎没有缺陷。
模型的最后一层是 LSTM。假设您正在进行分类/回归。这之后应该是一个密集层(SoftMax/sigmoid - 分类,线性 - 回归)。但由于这是一个时间序列问题,因此应将密集层包装在 TimeDistributed 包装器中。
在 LSTM 之上应用 LeakyReLU 很奇怪。
我已经修复了上述问题的代码。看看是否有帮助。
from tensorflow.keras.layers import Embedding, Input, Bidirectional, LSTM, Dense, Concatenate, LeakyReLU, TimeDistributed
from tensorflow.keras.initializers import Constant
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential
def build_model():
# define model
model = Sequential()
model.add(LSTM(29, return_sequences=True, input_shape=(29, 1)))
model.add(TimeDistributed(Dense(1)))
model.compile(optimizer='sgd', loss='mse', metrics = ['mae'])
print(model.summary())
return model
model = build_model()
推荐阅读
- python - openpyxl 无法读取严格的 Open XML 电子表格格式:用户警告:文件包含 Sheet1 的无效规范。这将被删除
- r - 设置颜色以显示清晰的数字
- javascript - 使用 HtmlElementView Widget 时如何触发函数?
- java - 使用 Oauth 2 对服务器进行身份验证/自动化的最佳方法是什么?
- mysql - 查找行的百分比mysql查询
- uwp - Hololens 2 上的 MediaPlayer 和 MediaStreamSource
- php - MySQL 仅在字段为空或 NULL 时更新字段
- react-native - 嵌套屏幕上“ModalPresentationIOS”类型的反应导航器模式具有不稳定的行为
- pandas - 如何根据单列值动态更新熊猫中的两列?
- python - 如何在python中将具有多个部分的段落转换为json?