python - 了解我的 LSTM 模型的结构
问题描述
我正在尝试解决以下问题:
我有来自许多设备的时间序列数据。每个设备记录的长度为 3000。捕获的每个数据点都有 4 个测量值。所以我的数据是成形的(设备记录的数量,3000、4)。
我正在尝试生成一个长度为 3000 的向量,其中每个数据点是 3 个标签(y1、y2、y3)之一,所以我想要的输出暗淡是(设备记录数,3000、1)。我已经标记了用于训练的数据。
我正在尝试为此使用 LSTM 模型,因为“随着时间序列数据移动时的分类”似乎是 RNN 类型的问题。
我的网络设置如下:
model = Sequential()
model.add(LSTM(3, input_shape=(3000, 4), return_sequences=True))
model.add(LSTM(3, activation = 'softmax', return_sequences=True))
model.summary()
总结如下:
Model: "sequential_23"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_29 (LSTM) (None, 3000, 3) 96
_________________________________________________________________
lstm_30 (LSTM) (None, 3000, 3) 84
=================================================================
Total params: 180
Trainable params: 180
Non-trainable params: 0
_________________________________________________________________
在输出空间中一切看起来都很好,因为我可以使用每个单元的结果来确定我的三个类别中的哪一个属于那个特定的时间步(我认为)。
但我只有 180 个可训练参数,所以我猜我做错了什么。
有人可以帮我理解为什么我的可训练参数这么少吗?我是否误解了如何设置这个 LSTM?我只是在担心什么吗?
那 3 个单位是否意味着我只有 3 个 LSTM“块”?并且它只能回顾 3 个观察结果?
解决方案
在一个简单的观点中,您可以将一个LSTM
层视为具有内存的增强Dense
层(因此可以有效地处理序列)。所以“单元”的概念对于两者来说也是相同的:这些层的神经元或特征单元的数量,或者换句话说,这些层可以从输入中提取的独特特征的数量。
因此,当您将层的单元数指定为 3 时LSTM
,或多或少意味着该层只能从输入时间步长中提取 3 个不同的特征(请注意,单元数与输入序列的长度无关,即LSTM
无论单元的数量或输入序列的长度是多少,该层都将处理整个输入序列)。
通常,这可能不是最理想的(不过,这实际上取决于您正在处理的特定问题和数据集的难度;即,对于您的问题/数据集,也许 3 个单位可能就足够了,您应该尝试找出答案)。因此,通常会为单元的数量选择更大的数字(常见选择:32、64、128、256),并且分类任务也被委托给Dense
位于顶部的专用层(或有时称为“softmax 层”)该模型。
例如,考虑到问题的描述,具有 3 个堆叠LSTM
层和Dense
顶部分类层的模型可能如下所示:
model = Sequential()
model.add(LSTM(64, return_sequences=True, input_shape=(3000, 4)))
model.add(LSTM(64, return_sequences=True))
model.add(LSTM(32, return_sequences=True))
model.add(Dense(3, activation = 'softmax'))
推荐阅读
- reactjs - MultiSelect 在值状态更改时不更新值(PrimeReact UI)
- python - 从 pyplot.plot 动作的循环中获取 ylim - 只有循环寄存器中的最后一个,其他一切都是 0、1
- javascript - 随着 Firebase RTDB 上的数据发生变化,实时更新 Html 表
- azure-devops - 错误找不到工作项类型产品待办事项项
- html - SVG 图像在 Safari 中未显示正确的字体
- google-cloud-platform - 谷歌云不允许我创建我的第一个具有管理员角色的项目
- java - Javax 验证 API。将父 bean 中的字段添加到验证消息
- npm-install - 卡在 npm install - M1 Mac
- javascript - Vuelidate根据服务器端响应显示错误消息?
- javascript - 在电子生产模式下将 SQLITE db 文件存储在哪里?