Skip to main content

生成扩散模型 - 条件控制生成

· 3 min read
PuQing
AI, CVer, Pythoner, Half-stack Developer

从生成手段上看,条件控制生成有两种:事后修改 (Classifier-Guidance) 和事前训练 (Classifier-Free)。

利用已经训练好的生成模型,通过一个分类器来调控生成过程,这就是事后修改的方法,因为从头到位训练一个生成模型训练成本太大了。而对于大公司来说,不缺算力,所以一般采用的是在训练过程中加入训练信号,达到更好的训练生成效果,这就是 Classifier-Free 方案。

条件输入

生成模型最关键的就是对于 p(xt1xt)p(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_{t}) 的建模,而条件生成就是以条件 y\boldsymbol{y} 作为条件输入,而这时的条件概率分布就可以写为 p(xt1xt,y)p(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_{t},\boldsymbol{y})。为了重用已经训练好的无条件生成模型 p(xt1,xt)p(\boldsymbol{x}_{t-1},\boldsymbol{x}_{t}),我们利用贝叶斯定理:

p(xt1y)=p(xt1)p(yxt1)p(y)p(\boldsymbol{x}_{t-1}\mid\boldsymbol{y})= \frac{p(\boldsymbol{x}_{t-1})p(\boldsymbol{y}\mid \boldsymbol{x}_{t-1})}{p(\boldsymbol{y})}

补上 xt\boldsymbol{x}_{t} 条件:

p(xt1xt,y)=p(xt1xt)p(yxt1,xt)p(yxt)p(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_{t}, \boldsymbol{y})= \frac{p(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_{t}) p(\boldsymbol{y}\mid \boldsymbol{x}_{t-1},\boldsymbol{x}_{t})}{p(\boldsymbol{y}\mid \boldsymbol{x}_{t})}

注意到 xt\boldsymbol{x}_{t}xt1\boldsymbol{x}_{t-1} 加噪声得到的,我们假定对条件概率 p(yxt1)p(\boldsymbol{y}\mid \boldsymbol{x}_{t-1}) 添加条件 xt\boldsymbol{x}_{t} 是不影响分类结果的,即 p(yxt1,xt)=p(yxt1)p(\boldsymbol{y}|\boldsymbol{x}_{t-1},\boldsymbol{x}_{t})=p(\boldsymbol{y}\mid \boldsymbol{x}_{t-1})。从而

p(xt1xt,y)=p(xt1xt)p(yxt1)p(yxt)=p(xt1xt)elogp(yxt1)logp(yxt)p\left(x_{t-1} \mid x_{t}, y\right)=\frac{p\left(x_{t-1} \mid x_{t}\right) p\left(y \mid x_{t-1}\right)}{p\left(y \mid x_{t}\right)}=p\left(x_{t-1} \mid x_{t}\right) e^{\log p\left(y \mid x_{t-1}\right)-\log p\left(y \mid x_{t}\right)}

近似分布

我们利用泰勒展开近似的求解差值:

logp(yxt1)logp(yxt)(xt1xt)xtlogp(yxt)\log p\left(y \mid x_{t-1}\right)-\log p\left(y \mid x_{t}\right) \approx\left(x_{t-1}-x_{t}\right) \cdot \nabla_{x_{t}} \log p\left(y \mid x_{t}\right)

而对于条件概率分布 p(xt1xt)p(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_{t}) 有:

p(xt1xt)=N(xt1;μ(xt),σt2I)ext1μ(xt)2/2σt2p\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}\right)=\mathcal{N}\left(\boldsymbol{x}_{t-1} ; \boldsymbol{\mu}\left(\boldsymbol{x}_{t}\right), \sigma_{t}^{2} \boldsymbol{I}\right) \propto e^{-\left\|\boldsymbol{x}_{t-1}-\boldsymbol{\mu}\left(\boldsymbol{x}_{t}\right)\right\|^{2} / 2 \sigma_{t}^{2}}

于是我们可以进一步得到

p(xt1xt,y)ext1μ(xt)2/2σt2+(xt1xt)xtlogp(yxt)ext1μ(xt)σt2xtlogp(yxt))2/2σt2\begin{aligned} p\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, \boldsymbol{y}\right) & \propto e^{-\left\|\boldsymbol{x}_{t-1}-\mu\left(\boldsymbol{x}_{t}\right)\right\|^{2} / 2 \sigma_{t}^{2}+\left(\boldsymbol{x}_{t-1}-\boldsymbol{x}_{t}\right) \cdot \nabla_{x_{t}} \log p\left(\boldsymbol{y} \mid \boldsymbol{x}_{t}\right)} \\ & \propto e^{\left.-\| \boldsymbol{x}_{t-1}-\mu\left(\boldsymbol{x}_{t}\right)-\sigma_{t}^{2} \nabla_{x_{t}} \log p\left(\boldsymbol{y} \mid \boldsymbol{x}_{t}\right)\right) \|^{2} / 2 \sigma_{t}^{2}} \end{aligned}

所以 p(xt1xt,y)p(\boldsymbol{x}_{t-1}\mid \boldsymbol{x}_{t},\boldsymbol{y}) 可以近似为 N(xt1;μ(xt)+σt2xtlogp(yxt),σt2I)\mathcal{N}\left(\boldsymbol{x}_{t-1} ; \boldsymbol{\mu}\left(\boldsymbol{x}_{t}\right)+\sigma_{t}^{2} \nabla_{\boldsymbol{x}_{t}} \log p\left(\boldsymbol{y} \mid \boldsymbol{x}_{t}\right), \sigma_{t}^{2} \boldsymbol{I}\right)

xt1=μ(xt)+σt2xtlogp(yxt)新增项 +σtε,εN(0,I)\boldsymbol{x}_{t-1}=\boldsymbol{\mu}\left(\boldsymbol{x}_{t}\right)+\underbrace{\sigma_{t}^{2} \nabla_{x_{t}} \log p\left(y \mid x_{t}\right)}_{\text {新增项 }}+\sigma_{t} \boldsymbol{\varepsilon}, \quad \boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, \boldsymbol{I})