python - 了解 Flatten 在 Keras 中的作用并确定何时使用它
问题描述
我试图了解为时间序列预测开发的模型。它使用一个 Con1D 层和两个 LSTM 层,然后是一个密集层。我的问题是,它应该Flatten()
在 LSTM 和 Denser 层之间使用吗?在我看来,输出应该只有一个值,其形状为(None, 1)
,并且可以通过Flatten()
在 LSTM 和 Dense 层之间使用来实现。没有Flatten()
,输出形状将是(None, 30, 1)
。或者,我可以return_sequences=True
从第二个 LSTM 层中删除 ,我认为它与Flatten()
. 哪种方式更合适?它们会影响损失吗?这是模型。
model = tf.keras.models.Sequential([
tf.keras.layers.Conv1D(filters=32, kernel_size=3, strides=1, padding="causal", activation="relu", input_shape=(30 ,1)),
tf.keras.layers.LSTM(32, return_sequences=True),
tf.keras.layers.LSTM(32, return_sequences=True),
# tf.keras.layers.Flatten(),
tf.keras.layers.Dense(1),
])
这是没有的模型摘要Flatten()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv1d (Conv1D) (None, 30, 32) 128
_________________________________________________________________
lstm (LSTM) (None, 30, 32) 8320
_________________________________________________________________
lstm_1 (LSTM) (None, 30, 32) 8320
_________________________________________________________________
dense (Dense) (None, 30, 1) 33
=================================================================
Total params: 16,801
Trainable params: 16,801
Non-trainable params: 0
_________________________________________________________________
解决方案
好吧,这取决于您想要实现的目标。我试着给你一些提示,因为我不是 100% 清楚你想要得到什么。
如果您的 LSTM 使用return_sequences=True
,那么您将返回每个 LSTM 单元的输出,即每个时间戳的输出。如果您随后添加一个密集层,其中一个将添加到每个 LSTM 层的顶部。
如果您将 flatten 层与 一起使用return_sequences=True
,那么您基本上是在删除时间维度,就像(None, 30)
您的情况一样。然后,您可以添加密集层或任何您需要的层。
如果您设置return_sequences=False
,您只会在 LSTM 的最后获得输出(请注意,在任何情况下,由于 LSTM 功能,它基于之前时间戳发生的计算),并且输出将是形状(None, dim)
其中dim
等于您在 LSTM 中使用的隐藏单元的数量(即 32)。同样,在这里,您可以简单地添加一个带有一个隐藏单元的密集层,以获得您正在寻找的东西。
推荐阅读
- android - 添加片段标签并检索它
- git - Visual Studio Code 中的“Git:执行 git 失败”
- javascript - React 的 Virtual DOM 如何比 DOM 快?
- java - 方法不会覆盖或实现超类型中的方法 - Reactnative
- c++ - namespace::variable 的多重定义,即使使用 ifndef
- html - HTML5 标签在 Android 上未正确呈现
- android - 在 Oppo F1 中处理程序的 postDelayed() 延迟结束后,Runnable 未完全执行
- angular - Angular:如何通过 API 调用使用 JWT 身份验证
- java - 图像背景要删除并设置为透明图像
- java - 匹配输入中两个不同位置的正则表达式