tensorflow - 运行 LSTM 模型时出错,损失:NaN 值
问题描述
我使用 Keras 和 Tensorflow 的 LSTM 模型给出了loss: nan
值。
我试图降低学习率,但仍然得到 nan 并降低整体准确性,并且还用于np.any(np.isnan(x_train))
检查我可能介绍自己的 nan 值(没有找到 nan)。我还阅读了有关爆炸梯度的信息,但似乎找不到任何可以帮助解决我的具体问题的信息。
我想我知道问题可能出在哪里,但不太确定。这是我实施的构建过程x_train
例如:
a = [[1,0,..0], [0,1,..0], [0,0,..1]]
a.shape() # (3, 20)
b = [[0,0,..1], [0,1,..0], [1,0,..0], [0,1,..0]]
b.shape() # (4, 20)
为了确保形状相同,我将一个向量[0,0,..0]
(全为零)附加到a
现在的形状(4,20)
。
a
并附b
加以给出 3D 数组形状(2,4,20)
,并形成x_train
. 但我认为附加 0 的空向量是出于某种原因让我有一段时间loss: nan
训练我的模型。这是我可能出错的地方吗?
nba+b
是一个 numpy 数组,我的实际x_train.shape
是(1228, 1452, 20)
•编辑•model.summary()
添加如下:
x_train shape: (1228, 1452, 20)
y_train shape: (1228, 1452, 8)
x_val shape: (223, 1452, 20)
x_val shape: (223, 1452, 8)
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
unified_lstm (UnifiedLSTM) (None, 1452, 128) 76288
_________________________________________________________________
batch_normalization_v2 (Batc (None, 1452, 128) 512
_________________________________________________________________
unified_lstm_1 (UnifiedLSTM) (None, 1452, 128) 131584
_________________________________________________________________
batch_normalization_v2_1 (Ba (None, 1452, 128) 512
_________________________________________________________________
dense (Dense) (None, 1452, 32) 4128
_________________________________________________________________
dense_1 (Dense) (None, 1452, 8) 264
=================================================================
Total params: 213,288
Trainable params: 212,776
Non-trainable params: 512
解决方案
解决方案是使用Masking()
keras 中可用的层和mask_value=0
. 这是因为当使用空向量时,它们被计算到损失中,通过使用Masking()
,正如 keras 所概述的那样,填充向量被跳过而不包括在内。
根据 keras 文档:
'如果给定样本时间步的所有特征都等于 mask_value,则样本时间步将在所有下游层中被屏蔽(跳过)(只要它们支持屏蔽) '
推荐阅读
- database - 如何在 DigitalOcean droplet 上安装和启动 Exasol 社区?
- javascript - 检查 Laravel-blade 中所有复选框的好方法
- mysql - Python MySQL连接器插入,但信息实际上不在数据库中
- angular - CKEditor4 w/Angular8:工具栏中缺少按钮
- python - 与等待并行运行方法
- javascript - 承诺控制流
- reactjs - 当我尝试导入 DropDown react-native 时出错
- javascript - 如何设置 Firebase 规则以仅允许我的网站将数据读/写到 Firebase 存储?
- c# - CSharp 和 CSharp-NetCore 生成器之间的 OpenAPI 区别
- bash - 更改 .bashrc 和 .bash_profile 后终端无法正确更新