python - 图神经网络中的梯度爆炸问题
问题描述
我有一个梯度爆炸问题,尝试了几天后我无法解决。我在 TensorFlow 中实现了一个自定义消息传递图神经网络,用于从图数据中预测连续值。每个图表与一个目标值相关联。图的每个节点由节点属性向量表示,节点之间的边由边属性向量表示。
在消息传递层内,节点属性以某种方式更新(例如,通过聚合其他节点/边缘属性),并返回这些更新的节点属性。
现在,我设法弄清楚我的代码中出现梯度问题的位置。我有下面的片段。
to_concat = [neighbors_mean, e]
z = K.concatenate(to_concat, axis=-1)
output = self.Net(z)
这里,neighbors_mean
是两个节点属性 之间的元素平均值vi
,vj
它形成了具有边缘属性的边缘e
。Net
是一个单层前馈网络。这样,训练损失在大约 30 个 epoch 后突然跳到 NaN,批量大小为 32。如果批量大小为 128,梯度在大约 200 个 epoch 后仍然会爆炸。
我发现,在这种情况下,渐变会因为 edge 属性而爆炸e
。如果我没有连接neighbors_mean
并e
只使用下面的代码,就不会有梯度爆炸。
output = self.Net(neighbors_mean)
e
我也可以通过如下的 sigmoid 函数发送来避免梯度爆炸。但这会降低性能(最终 MAE),因为其中的值e
被非线性映射到 0-1 范围。请注意,整流线性单元(ReLU) 而不是 sigmoid 不起作用。
to_concat = [neighbors_mean, tf.math.sigmoid(e)]
z = K.concatenate(to_concat, axis=-1)
output = self.Net(z)
只需提一下,它e
带有与两个相应节点之间的距离相关的单个值,并且该距离始终在 0.5-4 范围内。中没有大值或 NaN e
。
我有一个自定义的损失函数来训练这个模型,但是我发现这不是损失的问题(其他损失也导致了同样的问题)。下面是我的自定义损失函数。请注意,虽然这是一个单输出回归网络,但我的 NN 的最后一层有两个神经元,与预测的均值和 log(sigma) 有关。
def robust_loss(y_true, y_pred):
"""
Computes the robust loss between labels and predictions.
"""
mean, sigma = tf.split(y_pred, 2, axis=-1)
# tried limiting 'sigma' with sigma = tf.clip_by_value(sigma,-4,1.0) but the gradients still explode
loss = np.sqrt(2.0) * K.abs(mean - y_true) * K.exp(-sigma) + sigma
return K.mean(loss)
我基本上尝试了在线建议的所有内容以避免梯度爆炸。
- 应用渐变剪裁 - with
Adam(lr, clipnorm=1, clipvalue=5)
和 also withtf.clip_by_global_norm(gradients, 1.0)
- 我的目标变量总是按比例缩放
- 权重用
glorot_uniform
分布初始化 - 对权重应用正则化
- 尝试了更大的批量(直到 256,尽管在某些时候会发生延迟梯度爆炸)
- 尝试降低学习率
我在这里想念什么?我绝对知道它与连接有关e
。但是鉴于 0.5<e<4,为什么在这种情况下梯度会爆炸?这个功能e
对我很重要。我还能做些什么来避免模型中的数值溢出?
解决方案
看起来很棒,因为您已经遵循了大多数解决梯度爆炸问题的解决方案。以下是您可以尝试的所有解决方案的列表
避免梯度爆炸问题的解决方案
适当的权重初始化:根据使用的激活函数使用适当的权重初始化。
初始化 激活函数 他 ReLU 和变体 乐存 赛卢 格洛罗特 Softmax、Logistic、无、Tanh 重新设计你的神经网络:在神经网络中使用更少的层和/或使用更小的批量大小
选择非饱和激活函数:选择正确的激活函数并降低学习率
- ReLU
- 泄漏的 ReLU
- 随机泄漏 ReLU (RReLU)
- 参数泄漏 ReLU (PReLU)
- 指数线性单位 (ELU)
批量归一化:理想情况下,根据最适合您的数据集的方法,在每一层之前/之后使用批量归一化。
每层后论文参考
model = keras.models.Sequential([ keras.layers.Flatten(input_shape=[28, 28]), keras.layers.BatchNormalization(), keras.layers.Dense(300, activation="elu", kernel_initializer="he_normal"), keras.layers.BatchNormalization(), keras.layers.Dense(100, activation="elu", kernel_initializer="he_normal"), keras.layers.BatchNormalization(), keras.layers.Dense(10, activation="softmax") ])
在每一层之前
model = keras.models.Sequential([ keras.layers.Flatten(input_shape=[28, 28]), keras.layers.BatchNormalization(), keras.layers.Dense(300, kernel_initializer="he_normal", use_bias=False), keras.layers.BatchNormalization(), keras.layers.Activation("elu"), keras.layers.Dense(100, kernel_initializer="he_normal", use_bias=False), keras.layers.Activation("elu"), keras.layers.BatchNormalization(), keras.layers.Dense(10, activation="softmax") ])
渐变剪裁:好的默认值是 clipnorm=1.0 和 clipvalue=0.5
确保使用正确的优化器:由于您使用了 Adam 优化器,请检查其他优化器是否最适合您的情况。有关可用优化器的信息,请参阅此文档[SGD、RMSprop、Adam、Adadelta、Adagrad、Admax、Nadam、Ftrl]
随时间截断的反向传播:通常适用于 RNNS,请参阅此文档
使用 LSTM(RNN 的解决方案)
在层上使用权重正则化器:设置
kernel_regularizer
为 L1 或 L2。权重正则化器文档参考
有关更多信息,请参阅Aurélien编写的使用 scikit-learn、keras 和 tensorflow 进行机器学习的第11 章
推荐阅读
- ios - 如何与测试用户一起测试“使用 Apple 登录”?
- java - 如何将数据从一个表复制到另一个表,然后使用 Java 删除第一个表中的数据?
- xslt - 我们可以在调用 number() 函数是 xslt 时使用加号参数吗
- python - 电报editMessageMedia替代telepot
- html - 随机生成“id”属性值的缺点是什么?
- html - 自动播放嵌入的 YouTube 视频
- python - 在 Python 中比较两个列表时匹配字符串
- python - 按值将一列拆分为熊猫中的两列
- java - 在 Azure AD 身份验证 Spring 启动应用程序之后手动分配角色
- apache-spark - 将结果数据集减少为单个数据集