首页 > 技术文章 > Variational Lower Bound的一个Extension

alexdeblog 2014-02-24 13:02 原文

在generative model中,一般通过maximum likelihood来学习模型参数。含有隐藏变量时就需要将隐藏变量marginalize out,用marginal likelihood \(p(x) = \sum_h p(x, h)\)。对于log-likelihood,以下这个变换式是一个很常见的结论

$$\begin{align*}\log \sum_h p(x,h) &= \log p(x)\\ &= \sum_h q(h) \log p(x) \\ &= \sum_h q(h) \log \frac{p(x,h)}{p(h|x)} \\ &= \sum_h q(h) \log \frac{p(x,h)}{q(h)}\frac{q(h)}{p(h|x)} \\ &= \sum_h q(h)\log p(x,h) - \sum_h q(h) \log q(h) + \sum_h q(h)\log \frac{q(h)}{p(h|x)} \\ &= \mathbb{E}_q[\log p(x,h)] + \mathcal{H}(q) + \mathrm{KL}(q(h)||p(h|x))\end{align*}$$

这里\(q\)是任意概率分布,\(\mathbb{E}_q\)是在\(q(h)\)下的期望,\(\mathcal{H}(q)\)是\(q\)的熵,KL代表KL-divergence。第二个等式是因为\(\sum_h q(h)=1\),第三个等式是因为\(p(x,h)=p(x)p(h|x)\)对任意h成立。

由于KL-divergence恒为非负,由上面这个式子就可以得到一个\(\log p(x)\)的下界(variational lower bound)

$$\log p(x) \ge \mathbb{E}_q[\log p(x,h)] + \mathcal{H}(q) = \sum_h q(h) \log p(x,h) - \sum_h q(h) \log q(h)$$

这个下界的好坏和\(q\)的选取有很大关系,选取一个不好的\(q\)可以使这个界非常松,而当\(q(h)=p(h|x)\)时,KL-divergence为零,左右两边相等。

 

对于含隐藏变量的模型来说\(p(x)\)往往难以计算,因为将\(h\) marginalize out的操作涉及到一个指数量级的求和(另一方面,若\(p(x)\)容易求则不需要隐藏变量模型了)。对于大部分含隐藏变量的模型,\(p(h|x)\)也难以计算("explaining away"),于是常常用一个更容易求的\(q(h)\)来做近似。例如mean-field inference就用一个factorial distribution \(q(h)=\prod_i q_i(h_i)\)来做近似,通过优化上面的下界可以找到所有factorial distribution中的最优近似。

 

今天要提到的是上面对于\(\log \sum_h p(x,h)\)的变换式其实可以generalize。有两点:(1)不需要x;(2)不需要p是一个概率分布(不需要归一化)。上面变换的核心是把log中对\(h\)的求和变到log外,更泛化的形式是对\(\log \sum_h \exp(f_x(h))\)的变换。对上面的讨论,\(f_x(h)=\log p(x,h)\)。下面的讨论中忽略\(x\),而考虑任意函数\(f(h)\):

$$\begin{align*}\log \sum_h \exp(f(h)) &= \log \sum_h q(h) \frac{\exp(f(h))}{q(h)}\\ &\ge \sum_h q(h)\log \frac{\exp(f(h))}{q(h)}\end{align*}$$

其中\(q(h)\)为任意概率分布,第二式用到了Jensen's inequality。左右两边的差值为

$$\begin{align*}\log \sum_h \exp(f(h)) - \sum_h q(h)\log \frac{\exp(f(h))}{q(h)} &= \sum_h q(h) \log \sum_{h'} \exp(f(h')) - \sum_h q(h)\log \frac{\exp(f(h))}{q(h)} \\ &= \sum_h q(h) \log q(h)\frac{\sum_{h'} \exp(f(h'))}{\exp(f(h))}\\ &= \sum_h q(h)\log \frac{q(h)}{\exp(f(h))/\sum_{h'} \exp(f(h'))}\end{align*}$$

若定义\(p^*(h)=\frac{\exp(f(h))}{\sum_{h'} \exp(f(h'))}\),则\(p^*\)显然是一个概率分布,而上式即为\(\mathrm{KL}(q||p^*)\)。

 

因此,我们有变换式

$$\log \sum_h \exp(f(h)) = \sum_h q(h) \log \frac{\exp(f(h))}{q(h)} + \mathrm{KL}(p||q^*) = \mathbb{E}_q[f(h)] + \mathcal{H}(q) + \mathrm{KL}(q||p^*)$$

将\(f(h)=\log p(x,h)\)带入,即得到\(p^*(h)=p(h|x)\),且上面的变换式就是我们之前得到的变换式。由此,任意\(\log \sum_h \exp(f(h))\)形式的式子都可以有一个variational lower bound,也可以用一系列相应的方法来进行优化了。

推荐阅读