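python - LSTM optimization with Keras
Problem description
Update: Thank you, Catalina, for taking the time.
I did what you suggested: 1- I split the data into training and validation sets and passed the validation set to fit: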
history = model_final.fit(x_train_multi,
y_train_multi,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_data=(x_val_multi, y_val_multi),
callbacks=[tensorboard_callback,model_checkpoint_callback]
)
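Since each sample is a window of consecutive days, one reasonable way to make this split is chronological: keep the most recent windows for validation, so the model is validated on days that come after the ones it was trained on. A minimal NumPy sketch; chronological_split is a hypothetical helper, and x_multi, y_multi and the 20% validation fraction are assumptions standing in for the real data:

```python
import numpy as np


def chronological_split(x, y, val_fraction=0.2):
    """Split windowed samples into train/validation without shuffling.

    Assumes x and y are ordered by time, so the last `val_fraction`
    of the windows (the most recent ones) become the validation set.
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of samples")
    n_val = int(len(x) * val_fraction)
    split = len(x) - n_val
    return x[:split], y[:split], x[split:], y[split:]


# Hypothetical toy data standing in for the real windows:
# 100 samples of 30 days x 2 features, each with a 7-value target.
x_multi = np.random.rand(100, 30, 2).astype("float32")
y_multi = np.random.rand(100, 7).astype("float32")

x_train_multi, y_train_multi, x_val_multi, y_val_multi = chronological_split(
    x_multi, y_multi, val_fraction=0.2
)
print(x_train_multi.shape, x_val_multi.shape)  # (80, 30, 2) (20, 30, 2)
```

Keras still shuffles the training batches inside fit by default; only the boundary between training and validation keeps the time order.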
2- I also created a checkpoint and added it to the callbacks:
checkpoint_filepath = 'checkpoint_model/'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='mse',
mode='min',
save_best_only=True)
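Note that monitor='mse' tracks the training error, so with save_best_only=True the checkpoint keeps the epoch with the lowest training MSE rather than the one that does best on unseen data. Now that validation data is passed to fit, monitoring 'val_mse' is probably closer to the intent. A sketch assuming TensorFlow 2.x (tf.keras) and random toy data; the best.weights.h5 file name is an assumption (recent Keras versions require the .weights.h5 suffix when save_weights_only=True), and the metric is named explicitly so that the logged key is val_mse:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy stand-ins for the real windows (30 days x 2 features -> 7 values).
x_train = np.random.rand(64, 30, 2).astype("float32")
y_train = np.random.rand(64, 7).astype("float32")
x_val = np.random.rand(16, 30, 2).astype("float32")
y_val = np.random.rand(16, 7).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(30, 2)),
    tf.keras.layers.LSTM(32),
    tf.keras.layers.Dense(7),
])
# Name the metric explicitly so the logged keys are "mse" and "val_mse".
model.compile(optimizer="adam", loss="mse",
              metrics=[tf.keras.metrics.MeanSquaredError(name="mse")])

checkpoint_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_mse",  # validation error, not training error
    mode="min",
    save_best_only=True,
)

history = model.fit(x_train, y_train, batch_size=16, epochs=3,
                    validation_data=(x_val, y_val),
                    callbacks=[checkpoint_cb], verbose=0)

# Restore the weights from the best validation epoch before predicting.
model.load_weights(checkpoint_path)
predictions = model.predict(x_val, verbose=0)
print(predictions.shape)  # (16, 7)
```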
3- I also tried adding dropout, adding more LSTM layers, and increasing the number of units per layer:
________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_50 (LSTM) (None, 128) 67072
_________________________________________________________________
activation_19 (Activation) (None, 128) 0
_________________________________________________________________
dense_48 (Dense) (None, 256) 33024
_________________________________________________________________
dense_49 (Dense) (None, 7) 1799
=================================================================
val_mse: 0.036
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_51 (LSTM) (None, 128) 67072
_________________________________________________________________
batch_normalization_11 (Batc (None, 128) 512
_________________________________________________________________
activation_20 (Activation) (None, 128) 0
_________________________________________________________________
dense_50 (Dense) (None, 256) 33024
_________________________________________________________________
dense_51 (Dense) (None, 7) 1799
=================================================================
val_mse: 0.91
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_54 (LSTM) (None, 128) 67072
_________________________________________________________________
activation_23 (Activation) (None, 128) 0
_________________________________________________________________
dropout_17 (Dropout) (None, 128) 0
_________________________________________________________________
dense_54 (Dense) (None, 256) 33024
_________________________________________________________________
dropout_18 (Dropout) (None, 256) 0
_________________________________________________________________
dense_55 (Dense) (None, 7) 1799
=================================================================
val_mse: 0.0281
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_55 (LSTM) (None, 30, 128) 67072
_________________________________________________________________
activation_24 (Activation) (None, 30, 128) 0
_________________________________________________________________
lstm_56 (LSTM) (None, 128) 131584
_________________________________________________________________
activation_25 (Activation) (None, 128) 0
_________________________________________________________________
dense_56 (Dense) (None, 256) 33024
_________________________________________________________________
dense_57 (Dense) (None, 7) 1799
=================================================================
val_mse: 0.0822
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_57 (LSTM) (None, 30, 128) 67072
_________________________________________________________________
activation_26 (Activation) (None, 30, 128) 0
_________________________________________________________________
lstm_58 (LSTM) (None, 128) 131584
_________________________________________________________________
activation_27 (Activation) (None, 128) 0
_________________________________________________________________
dense_58 (Dense) (None, 256) 33024
_________________________________________________________________
dense_59 (Dense) (None, 256) 65792
_________________________________________________________________
dense_60 (Dense) (None, 7) 1799
=================================================================
val_mse: 0.0541
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_61 (LSTM) (None, 30, 128) 67072
_________________________________________________________________
activation_30 (Activation) (None, 30, 128) 0
_________________________________________________________________
lstm_62 (LSTM) (None, 128) 131584
_________________________________________________________________
activation_31 (Activation) (None, 128) 0
_________________________________________________________________
dense_62 (Dense) (None, 7) 903
=================================================================
val_mse: 0.0713
4- I also tried adding BatchNormalization.
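The variants in 3- and 4- can be generated from one function instead of being edited by hand, which makes it easier to compare them on the same data. A sketch assuming tf.keras; build_model and its arguments are hypothetical. The defaults (30-day window, 2 input features, 7 outputs) are read off the summaries above: an LSTM(128) with 67,072 parameters corresponds to 2 input features. The standalone Activation layers are left out because their activation type is not shown, and the relu on the hidden Dense layer is likewise an assumption:

```python
import tensorflow as tf


def build_model(window=30, n_features=2, n_outputs=7,
                lstm_units=(128,), dense_units=256,
                dropout=0.0, batch_norm=False):
    """Build one LSTM variant. All argument names are illustrative."""
    model = tf.keras.Sequential([tf.keras.Input(shape=(window, n_features))])
    for i, units in enumerate(lstm_units):
        # Every LSTM except the last must return the full sequence
        # so the next LSTM receives (batch, window, units).
        is_last = i == len(lstm_units) - 1
        model.add(tf.keras.layers.LSTM(units, return_sequences=not is_last))
    if batch_norm:
        model.add(tf.keras.layers.BatchNormalization())
    if dropout > 0:
        model.add(tf.keras.layers.Dropout(dropout))
    if dense_units:
        model.add(tf.keras.layers.Dense(dense_units, activation="relu"))
        if dropout > 0:
            model.add(tf.keras.layers.Dropout(dropout))
    model.add(tf.keras.layers.Dense(n_outputs))
    model.compile(optimizer="adam", loss="mse",
                  metrics=[tf.keras.metrics.MeanSquaredError(name="mse")])
    return model


# Two LSTM layers followed directly by the output layer
# (same layer sizes as the last summary above).
two_layer = build_model(lstm_units=(128, 128), dense_units=0)
print(two_layer.count_params())  # 199559 = 67072 + 131584 + 903
```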
Nothing seems to help me accurately predict the next 5 days.
Is there anything else you think I could try?
Thanks!!
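I started using Keras in Python a few weeks ago, and now I'm trying to solve a forecasting problem. Basically, I have the number of active students and how many of them attend class each day. What I want to do is predict attendance for the next X days, where X is currently 10. For this I have data from January 1 to September 28. What I'm doing is grouping the number of students and the number of students attending each day into 30-day windows; I feed these to the RNN and get an array of 10 values as output.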
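The windowing described above can be sketched in NumPy. Assumptions: the daily data is an array of shape (n_days, 2) whose columns are the active-student count and the attendance count, and the target is the attendance column over the following 10 days; make_windows and the toy data are hypothetical:

```python
import numpy as np


def make_windows(series, target_col, past=30, future=10):
    """Turn a (n_days, n_features) array into supervised samples.

    Each input is `past` consecutive days of all features; each target is
    the `target_col` values for the `future` days that immediately follow.
    """
    x, y = [], []
    for start in range(len(series) - past - future + 1):
        end = start + past
        x.append(series[start:end])
        y.append(series[end:end + future, target_col])
    return np.array(x), np.array(y)


# Hypothetical toy data: column 0 = active students, column 1 = attendance.
n_days = 200
daily = np.column_stack([np.full(n_days, 100.0), np.arange(n_days, dtype=float)])
x, y = make_windows(daily, target_col=1, past=30, future=10)
print(x.shape, y.shape)  # (161, 30, 2) (161, 10)
```

Consecutive samples overlap by 29 days, so a random train/validation split would put nearly identical windows on both sides.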
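This is my model:
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_2 (LSTM) (None, 256) 265216
_________________________________________________________________
dense_1 (Dense) (None, 128) 32896
_________________________________________________________________
dense_2 (Dense) (None, 10) 1290
=================================================================
Total params: 299,402
Trainable params: 299,402
Non-trainable params: 0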
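The summary can be reproduced as below. The parameter counts pin down the input: an LSTM layer with u units and d input features has 4·u·(d + u + 1) parameters, and 4 × 256 × (d + 257) = 265,216 gives d = 2, which matches the two daily counts described above. The 30-day window comes from the text; the Dense activations are not visible in the summary, so relu on the hidden layer is an assumption (it does not change the parameter count):

```python
import tensorflow as tf

WINDOW, N_FEATURES, HORIZON = 30, 2, 10

model = tf.keras.Sequential([
    tf.keras.Input(shape=(WINDOW, N_FEATURES)),
    tf.keras.layers.LSTM(256),                      # 4*256*(2+256+1) = 265,216
    tf.keras.layers.Dense(128, activation="relu"),  # 256*128 + 128 = 32,896
    tf.keras.layers.Dense(HORIZON),                 # 128*10 + 10 = 1,290
])
model.summary()
print(model.count_params())  # 299402
```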
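I'm trying to improve my model to get the most accurate predictions possible, but I'm having a hard time doing so.
I would appreciate your opinions and suggestions.
Here is my Google Colab notebook: https://colab.research.google.com/drive/15RhjFwzjpUAhdXFWkn1dmRgFLoDQcuiU?usp=sharing
And the dataset: https://drive.google.com/file/d/1-HhaXSs4Bf0a0626bhAUCobFZRTasOGM/view?usp=sharing
Thanks in advance.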
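Solution
Here are some tips:
Split your training set into training and validation sets and give the validation set to your model (you pass it as a tuple argument to the fit function).
Add a checkpoint (a callback, passed as an argument to the fit function) that saves the best model, and then use that model when making predictions. See the TensorFlow/Keras documentation for more information.
Test different values for Dropout and the hidden size.
Add a BatchNormalization layer.
Try all of these and see which one gives the best results.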
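The last tip can be made systematic with a small grid search that trains each configuration and keeps the one with the lowest validation MSE. A sketch assuming tf.keras; the grid values, epoch count, make_model helper and random data are placeholders for illustration. With loss='mse', the logged val_loss is the validation MSE:

```python
import itertools

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
# Toy data: replace with the real training / validation windows.
x_train = rng.random((64, 30, 2), dtype=np.float32)
y_train = rng.random((64, 7), dtype=np.float32)
x_val = rng.random((16, 30, 2), dtype=np.float32)
y_val = rng.random((16, 7), dtype=np.float32)


def make_model(units, dropout):
    """Hypothetical helper: one LSTM layer, dropout, linear output."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(30, 2)),
        tf.keras.layers.LSTM(units),
        tf.keras.layers.Dropout(dropout),
        tf.keras.layers.Dense(7),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


results = {}
for units, dropout in itertools.product([16, 32], [0.0, 0.2]):
    model = make_model(units, dropout)
    history = model.fit(x_train, y_train, epochs=3, batch_size=16,
                        validation_data=(x_val, y_val), verbose=0)
    # Score each configuration by its best epoch's validation MSE.
    results[(units, dropout)] = min(history.history["val_loss"])

best = min(results, key=results.get)
print("best (units, dropout):", best, "val_mse:", results[best])
```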