python - 为什么在 LSTM 中添加 relu 激活后会得到 Nan?
问题描述
我有一个简单的 LSTM 网络,大致如下所示:
lstm_activation = tf.nn.relu
cells_fw = [LSTMCell(num_units=100, activation=lstm_activation),
LSTMCell(num_units=10, activation=lstm_activation)]
stacked_cells_fw = MultiRNNCell(cells_fw)
_, states = tf.nn.dynamic_rnn(cell=stacked_cells_fw,
inputs=embedding_layer,
sequence_length=features['length'],
dtype=tf.float32)
output_states = [s.h for s in states]
states = tf.concat(output_states, 1)
我的问题是。当我不使用激活 (activation=None) 或使用 tanh 时,一切正常,但是当我切换 relu 时,我不断收到“训练期间的 NaN 损失”,这是为什么呢?它是 100% 可重现的。
解决方案
当您使用relu activation function
inside 时lstm cell
,可以保证单元的所有输出以及单元状态都是严格>= 0
的。正因为如此,你的渐变变得非常大并且正在爆炸。例如,运行以下代码片段并观察输出是 never < 0
。
X = np.random.rand(4,3,2)
lstm_cell = tf.nn.rnn_cell.LSTMCell(5, activation=tf.nn.relu)
hidden_states, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=X, dtype=tf.float64)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(hidden_states))
推荐阅读
- c++ - NAPI 模块在 Windows 上出现 LNK2019 错误
- angular - (新鲜的角度项目)当我将模块导入独立的功能模块时,main.js 文件的大小会增加
- python - 在python django中获取前几个月的数据
- elasticsearch - ApacheManifoldCF elasticsearch 输出连接器版本兼容性
- sql-server - Report Builder 3.0 在参数中使用多个值不适用于 IN 语句
- javascript - 在 Nuxt.js 中使用 Vue-Meta 内联 Js
- zeep - zeep.exceptions.XMLSyntaxError:找到的根元素是 html
- security - 在 ASP.NET Core 中处理错误的礼品卡密钥/密码尝试的最佳策略
- python - 如何使用 python 访问电子表格
- r - 如何加入整洁的数据集并合并列