变分推断公式推导
背景介绍
机器学习中的概率模型可分为频率派和贝叶斯派。频率派最终是求一个优化问题,而贝叶斯派则是求一个积分问题。
频率派
举几个例子:
线性回归
样本数据:{(xi,yi)}i=1N\{(x_i,y_i)\}_{i=1}^N{(xi,yi)}i=1N
-
模型:f(w)=wTxf(w)=w^Txf(w)=wTx
-
策略:损失函数:L(w)=∑i=1N∣∣wTxi−yi∣∣2L(w)=\sum_{i=1}^N||w^Tx_i-y_i||^2L(w)=∑i=1N∣∣wTxi−yi∣∣2,w^=argminwL(w)\hat{w}=\arg\min_wL(w)w^=argminwL(w) 这就是一个无约束优化问题。
-
算法:解法
- 解析解:线性回归问题形式比较简单,可直接由最小二乘法求出解析解:w∗=(XTX)−1XTYw^*=(X^TX)^{-1}X^TYw∗=(XTX)−1XTY
- 数值解:对于其他较为复杂的算法无法解析。有一些求数值解的方法,如梯度下降等。
SVM
- 模型:f(w)=sign(wTx+b)f(w)=sign(w^Tx+b)f(w)=sign(wTx+b)
- 策略:损失函数:min12wTws.t.yi(wTxi+b≥1)min\frac{1}{2}w^Tw\ \ \ \ s.t.\ y_i(w^Tx_i+b\ge 1)min21wTw s.t. yi(wTxi+b≥1)。 是一个有约束的凸优化问题。
- 算法:解法有 QP、拉格朗日对偶等。
EM算法
θ(t+1)=argmaxθ∫ZlogP(X,Z∣θ)P(Z∣X,θ(t))dZ\theta^{(t+1)}=\arg\max_{\theta}\int_Z\log P(X,Z|\theta)P(Z|X,\theta^{(t)})dZ θ(t+1)=argθmax∫ZlogP(X,Z∣θ)P(Z∣X,θ(t))dZ
EM算法也是通过迭代来求解最大对数似然的数值解。
贝叶斯派
为什么说贝叶斯派是求积分呢?我们先来看贝叶斯定理:
P(θ∣X)=P(X∣θ)P(θ)P(X)P(\theta|X)=\frac{P(X|\theta)P(\theta)}{P(X)} P(θ∣X)=P(X)P(X∣θ)P(θ)
贝叶斯推断,要求得后验 P(θ∣X)P(\theta|X)P(θ∣X) 。
贝叶斯决策。决策可以理解为就是做预测。即 XXX 为已知的 NNN 个样本数据。决策就是求:
P(x~∣X)=∫θP(x~∣X)dθ=∫θP(x~∣θ)P(θ∣X)dθP(\tilde{x}|X)=\int_\theta P(\tilde{x}|X)d\theta=\int_\theta P(\tilde{x}|\theta)P(\theta|X)d\theta P(x~∣X)=∫θP(x~∣X)dθ=∫θP(x~∣θ)P(θ∣X)dθ
在通过贝叶斯推断求得后验 P(θ∣X)P(\theta|X)P(θ∣X) 之后,就可以按照上式进行贝叶斯决策。而且上面这个式子也可以写成关于后验的期望的形式(期望就是求积分):
P(x~∣X)=Eθ∣X[P(x~∣θ)]P(\tilde{x}|X)=\mathbb{E}_{\theta|X}[P(\tilde{x}|\theta)] P(x~∣X)=Eθ∣X[P(x~∣θ)]
贝叶斯派的关键就是求得后验 P(θ∣X)P(\theta|X)P(θ∣X) ,即贝叶斯推断的过程。贝叶斯推断又可分为精确推断和近似推断:
- 精确推断
- 近似推断
- 确定性近似:变分推断(本文的主题)
- 随机近似:MCMC、MH、Gibbs
公式推导
符号含义:XXX 为观测数据,ZZZ 为隐变量和参数。注意这里参数 θ\thetaθ 也一同表示在 ZZZ 中了。
再强调一下我们的目的:求后验 P(Z∣X)P(Z|X)P(Z∣X) 。
下面的前几步与 EM 算法导出的做法类似,详见 EM算法公式推导 ,区别只是把参数 θ\thetaθ 合并到了 ZZZ 中,步骤这里就不一一说明了。
logP(X)=logP(X,Z)−logP(Z∣X)=logP(X,Z)q(Z)−logP(Z∣X)q(Z)=∫Zq(Z)logP(X,Z)q(Z)dZ−∫Zq(Z)logP(Z∣X)q(Z)dZ=ELBO+KL(q(Z)∣∣P(Z∣X))=L(q)+KL(q(Z)∣∣P(Z∣X))\begin{align} \log P(X)&=\log P(X,Z)-\log P(Z|X)\\ &=\log \frac{P(X,Z)}{q(Z)}-\log \frac{P(Z|X)}{q(Z)}\\ &=\int_Zq(Z)\log\frac{P(X,Z)}{q(Z)}dZ-\int_Zq(Z)\log \frac{P(Z|X)}{q(Z)}dZ\\ &=ELBO+KL(q(Z)||P(Z|X))\\ &=\mathcal{L}(q)+KL(q(Z)||P(Z|X)) \end{align} logP(X)=logP(X,Z)−logP(Z∣X)=logq(Z)P(X,Z)−logq(Z)P(Z∣X)=∫Zq(Z)logq(Z)P(X,Z)dZ−∫Zq(Z)logq(Z)P(Z∣X)dZ=ELBO+KL(q(Z)∣∣P(Z∣X))=L(q)+KL(q(Z)∣∣P(Z∣X))
经过一系列变形,得到 EBLO+KLEBLO+KLEBLO+KL 的形式,这里我们将 ELBOELBOELBO 记为 L(q)\mathcal{L}(q)L(q) ,就是所谓的变分。
我们是要求的是后验 P(Z∣X)P(Z|X)P(Z∣X) ,如果其与 q(Z)q(Z)q(Z) 的 KL 散度接近0,那么就能用 q(Z)q(Z)q(Z) 来对其进行近似。而等式左边 logP(X)\log P(X)logP(X) 与 ZZZ 无关,因此 ELBO+KLELBO+KLELBO+KL 在 q(Z)q(Z)q(Z) 变化时是个定值,因此,要让 KL 尽量小就转换为让 ELBO 尽量大,即有:
q^(Z)=argmaxq(Z)L(q)→q(Z)≈P(Z∣X)\hat{q}(Z)=\arg\max_{q(Z)}\mathcal{L}(q)\ \ \ \ \rightarrow\ \ \ \ q(Z)\approx P(Z|X) q^(Z)=argq(Z)maxL(q) → q(Z)≈P(Z∣X)
接下来,我们根据平均场理论,将 q(Z)q(Z)q(Z) 划分为 MMM 个相互独立的份:
q(Z)=∏i=1Mqi(Zi)q(Z)=\prod_{i=1}^Mq_i(Z_i) q(Z)=i=1∏Mqi(Zi)
之后在求解的时候,我们会先固定 q1,q2,…,qj−1,…,qMq_1,q_2,\dots,q_{j-1},\dots,q_Mq1,q2,…,qj−1,…,qM ,然后求解单个分量 qjq_jqj ,最后将所有分量连乘起来,得到完整的 q(Z)q(Z)q(Z) 。
首先先将 q(Z)q(Z)q(Z) 代回到原式中:
L(q)=∫Zq(Z)logP(X,Z)dZ−∫Zlogq(Z)dZ=①−②\mathcal{L}(q)=\int_Zq(Z)\log P(X,Z)dZ-\int_Z\log q(Z)dZ=①-②\\ L(q)=∫Zq(Z)logP(X,Z)dZ−∫Zlogq(Z)dZ=①−②
一项一项地来看:
①=∫Zq(Z)logP(X,Z)dZ=∫Z∏i=1Mqi(Zi)logP(X,Z)dZ=∫Zjqj(Zj)∫Zi(i≠j)∏i≠jMqi(Zi)logP(X,Z)dZi(i≠j)dZj=∫Zjqj(Zj)∫Zi(i≠j)logP(X,Z)∏i≠jMqi(Zi)dZi(i≠j)dZj=∫Zjqj(Zj)⋅E∏i≠jMqi(Zi)[logP(X,Z)]dZj\begin{align} ①&=\int_Zq(Z)\log P(X,Z)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\log P(X,Z)dZ\\ &=\int_{Z_j}q_j(Z_j)\int_{Z_i(i\ne j)}\prod_{i\ne j}^Mq_i(Z_i)\log P(X,Z)dZ_{i(i\ne j)}dZ_j\\ &=\int_{Z_j}q_j(Z_j)\int_{Z_i(i\ne j)}\log P(X,Z)\prod_{i\ne j}^Mq_i(Z_i)dZ_{i(i\ne j)}dZ_j\\ &=\int_{Z_j}q_j(Z_j)\cdot\mathbb{E}_{\prod_{i\ne j}^Mq_i(Z_i)}[\log P(X,Z)]dZ_j \end{align} ①=∫Zq(Z)logP(X,Z)dZ=∫Zi=1∏Mqi(Zi)logP(X,Z)dZ=∫Zjqj(Zj)∫Zi(i=j)i=j∏Mqi(Zi)logP(X,Z)dZi(i=j)dZj=∫Zjqj(Zj)∫Zi(i=j)logP(X,Z)i=j∏Mqi(Zi)dZi(i=j)dZj=∫Zjqj(Zj)⋅E∏i=jMqi(Zi)[logP(X,Z)]dZj
- 先将 q(Z)q(Z)q(Z) 进行拆分为 MMM 份;
- 然后将第 jjj 份拆出来;
- 其他份的积分写成期望的形式(见到积分,就考虑能写成期望)
然后看后面一项:
②=∫Zq(Z)logq(Z)dZ=∫Z∏i=1Mqi(Zi)log∏i=1Mqi(Zi)dZ=∫Z∏i=1Mqi(Zi)∑i=1Mlogqi(Zi)dZ=∫Z∏i=1Mqi(Zi)[logq1(Z1)+logq2(Z2)+⋯+logqM(ZM)]dZ\begin{align} ②&=\int_Zq(Z)\log q(Z)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\log\prod_{i=1}^M q_i(Z_i)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)\sum_{i=1}^M\log q_i(Z_i)dZ\\ &=\int_Z\prod_{i=1}^Mq_i(Z_i)[\log q_1(Z_1)+\log q_2(Z_2)+\dots+\log q_M(Z_M)]dZ\\ \end{align} ②=∫Zq(Z)logq(Z)dZ=∫Zi=1∏Mqi(Zi)logi=1∏Mqi(Zi)dZ=∫Zi=1∏Mqi(Zi)i=1∑Mlogqi(Zi)dZ=∫Zi=1∏Mqi(Zi)[logq1(Z1)+logq2(Z2)+⋯+logqM(ZM)]dZ
- 写成 MMM 份;
- log 里面乘变外面加;
- 把连加号写开;
- 然后我们看其中一项(比如第一项):
∫Z∏i=1Mqi(Zi)⋅logq1(Z1)dZ=∫Zq1(Z1)q2(Z2)…qM(ZM)logq1(Z1)dZ=∫Z1Z2…ZMq1(Z1)q2(Z2)…qM(ZM)logq1(Z1)dZ1dZ2…dZM=∫Z1q1(Z1)logq1(Z1)dZ1∏i=2M∫Ziqi(Zi)dZi=∫Z1q1(Z1)logq1(Z1)dZ1\begin{align} \int_Z\prod_{i=1}^Mq_i(Z_i)\cdot\log q_1(Z_1)dZ&=\int_Zq_1(Z_1)q_2(Z_2)\dots q_M(Z_M)\log q_1(Z_1)dZ\\ &=\int_{Z_1Z_2\dots Z_M}q_1(Z_1)q_2(Z_2)\dots q_M(Z_M)\log q_1(Z_1)dZ_1dZ_2\dots dZ_M\\ &=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1\prod_{i=2}^M\int_{Z_i}q_i(Z_i)dZ_i\\ &=\int_{Z_1}q_1(Z_1)\log q_1(Z_1)dZ_1 \end{align} ∫Zi=1∏Mqi(Zi)⋅logq1(Z1)dZ=∫Zq1(Z1)q2(Z2)…qM(ZM)logq1(Z1)dZ=∫Z1Z2…ZMq1(Z1)q2(Z2)…qM(ZM)logq1(Z1)dZ1dZ2…dZM=∫Z1q1(Z1)logq1(Z1)dZ1i=2∏M∫Ziqi(Zi)dZi=∫Z1q1(Z1)logq1(Z1)dZ1
- 把 q1(Z1)q_1(Z_1)q1(Z1) 相关的移到一起;
- 剩下的积分全都是 1
②=∑i=1M∫Ziqi(Zi)logqi(Zi)dZi=∫Zjqj(Zj)logqj(Zj)dZj+C\begin{align} ②&=\sum_{i=1}^M\int_{Z_i}q_i(Z_i)\log q_i(Z_i)dZ_i\\ &=\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ \end{align} ②=i=1∑M∫Ziqi(Zi)logqi(Zi)dZi=∫Zjqj(Zj)logqj(Zj)dZj+C
- 有了 i=1i=1i=1 时的表示,我们就把整个第二项写成连加的形式;
- 我们只关心第 jjj 项,其余的视作常数 CCC
这样处理完两项,有:
①−②=∫Zjqj(Zj)⋅E∏i≠jMqi(Zi)[logP(X,Z)]dZj−∫Zjqj(Zj)logqj(Zj)dZj+C=∫Zjqj(Zj)⋅logP^(X,Zj)dZj−∫Zjqj(Zj)logqj(Zj)dZj+C=∫Zjqj(Zj)⋅logP^(X,Zj)qj(Zj)dZj=−KL(P^(X,Zj)∣∣qj(Zj))≤0\begin{align} ①-②&=\int_{Z_j}q_j(Z_j)\cdot\mathbb{E}_{\prod_{i\ne j}^Mq_i(Z_i)}[\log P(X,Z)]dZ_j-\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ &=\int_{Z_j}q_j(Z_j)\cdot\log \hat{P}(X,Z_j) dZ_j-\int_{Z_j}q_j(Z_j)\log q_j(Z_j)dZ_j+C\\ &=\int_{Z_j}q_j(Z_j)\cdot\log\frac{ \hat{P}(X,Z_j)}{q_j(Z_j)}dZ_j\\ &=-KL(\hat{P}(X,Z_j)||q_j(Z_j))\le 0 \end{align} ①−②=∫Zjqj(Zj)⋅E∏i=jMqi(Zi)[logP(X,Z)]dZj−∫Zjqj(Zj)logqj(Zj)dZj+C=∫Zjqj(Zj)⋅logP^(X,Zj)dZj−∫Zjqj(Zj)logqj(Zj)dZj+C=∫Zjqj(Zj)⋅logqj(Zj)P^(X,Zj)dZj=−KL(P^(X,Zj)∣∣qj(Zj))≤0
- 将 ① 中的期望写成一个函数的形式:P^(X,Zj)\hat{P}(X,Z_j)P^(X,Zj) ;
- 最后就是一个负的 KL 散度,当 P^(X,Zj)=qj(Zj)\hat{P}(X,Z_j)=q_j(Z_j)P^(X,Zj)=qj(Zj) 时取到等号
Ref
- 机器学习白板推导