gradient - gpflow SVGP的爆炸梯度
问题描述
在为大数据集优化具有泊松似然的 SVGP 时,我看到了我认为的爆炸梯度。几个时期后,我看到 ELBO 急剧下降,然后在摆脱之前取得的所有进展后非常缓慢地恢复。大约 21 次迭代对应于一个 Epoch。
这个尖峰(至少是第二个尖峰)导致了参数的完全变化(对于参数向量,我刚刚绘制了范数以查看变化):
我该如何处理?我的第一种方法是裁剪渐变,但这似乎需要挖掘 gpflow 代码。
我的设置:
通过自然梯度对变分参数进行训练,通过 ADAM 对其余参数进行训练,自然梯度 Gamma 的时间表缓慢(线性)增加。
对于我的设置,批处理和诱导点的大小尽可能大(均为 2^12,数据集由约 88k 样本组成)。我包括 1e-5 抖动并用 kmeans 初始化诱导点。
我使用组合内核,由 RBF、Matern52、周期性和线性内核的组合组成,共有 95 个特征(其中很多是由于单热编码),所有特征都是可学习的。长度尺度使用 gpflow.transforms 进行转换。
with gpflow.defer_build():
k1 = Matern52(input_dim=len(kernel_idxs["coords"]), active_dims=kernel_idxs["coords"], ARD=False)
k2 = Periodic(input_dim=len(kernel_idxs["wday"]), active_dims=kernel_idxs["wday"])
k3 = Linear(input_dim=len(kernel_idxs["onehot"]), active_dims=kernel_idxs["onehot"], ARD=True)
k4 = RBF(input_dim=len(kernel_idxs["rest"]), active_dims=kernel_idxs["rest"], ARD=True)
#
k1.lengthscales.transform = gpflow.transforms.Exp()
k2.lengthscales.transform = gpflow.transforms.Exp()
k3.variance.transform = gpflow.transforms.Exp()
k4.lengthscales.transform = gpflow.transforms.Exp()
m = gpflow.models.SVGP(X, Y, k1 + k2 + k3 + k4, gpflow.likelihoods.Poisson(), Z,
mean_function=gpflow.mean_functions.Constant(c=np.ones(1)),
minibatch_size=MB_SIZE, name=NAME)
m.mean_function.set_trainable(False)
m.compile()
更新:仅使用 ADAM 按照 Mark 的建议,我只使用 ADAM,这帮助我摆脱了突然的爆炸。但是,我仍然只使用 natgrad 的一个 epoch 进行初始化,这似乎节省了很多时间。
此外,变分参数的变化似乎不那么突然(至少就它们的规范而言)。我想他们现在收敛速度会慢一些,但至少它是稳定的。
解决方案
只是添加到上面 Mark 的回答中,在非共轭模型中使用 nat grads 时,可能需要进行一些调整才能获得最佳性能,并且不稳定性可能是一个问题。正如 Mark 指出的那样,提供可能更快收敛的大步长也可能导致参数最终进入参数空间的不良区域。当变分近似很好(即真实和近似后验接近)时,有充分的理由期望 nat grad 表现良好,但不幸的是,在一般情况下没有灵丹妙药。有关一些直觉,请参见https://arxiv.org/abs/1903.02984。
推荐阅读
- c - 为什么 strtok 偶尔会导致总线错误?
- java - REST API 如何在线程完成工作后发送 http 响应
- r - 将 tibble 转换为模型类
- c++ - 在单个文件上定义预处理器定义
- reactjs - Web app built in React TS shows white screen after new build
- python - I am trying to import a fast food order program and it keeps saying "food_choice" is not defined
- visual-studio-code - What are the hotkeys to generate a method to Stringify multiple variables?
- python - 除法功能的问题
- ubuntu - Unable to add "local" Repository to sources.list on Ubuntu
- python - 从用户输入中获取列表中每个元素的长度