DDPM 模型将一张图片解构为 T 步,从原始的图片 x0 开始,经过 T 步 “ 分解 ” 得到随机杂乱的噪声 xt,即:
=x0→x1→x2→⋯→xT−1→xT=z
所以如果我们能够学会 xt→xt−1 步骤,则我们就可以从噪声恢复原始的图片。所以我们想要学习关系 xt−1=μ(xt),那我们从 xt 出发,反复执行 xt−1=μ(xt) 就能从中恢复。
DDPM 将图片分解的过程描述为:
xt=αtxt−1+βtεt,εt∼N(0,I)
其中有 αt,βt>0,并且有 αt2+β2=1,而 βt 通常接近于 0,可以形象的理解为对于原图的破坏程度,噪声 εt 的引入代表着对于原图的破坏。
反复执行这个分解步骤,我们可以得到:
xt=αtxt−1+βtεt=αt(αt−1xt−2+βt−1εt−1)+βtεt=⋯=(αt⋯α1)x0+多个相互独立的正态噪声之和 (αt⋯α2)β1ε1+(αt⋯α3)β2ε2+⋯+αtβt−1εt−1+βtεt
如果有 n 个随机变量 X1,X2,…,Xn。它们的方差分别为 Var(X1),Var(X2),…,Var(Xn)。并且它们两两之间的协方差为 Cov(Xi,Xj),那么它们的线性组合 Y=α1X1+α2X2+⋯+αnXn 的方差可以用下面公式计算:
Var(Y)=i=1∑nj=1∑nαiαjCov(Xi,Xj)当 X1,X2,…,Xn 相互独立时,它们的协方差矩阵为对角线矩阵,即 Cov(Xi,Xj)=0(i=j),因此上述公式化简为:
Var(Y)=i=1∑nαi2Var(Xi)
为什么我们限定
αt2+βt2=1注意到公式后面部分为多个均值为 0,方差分别为 (αt⋯a2)2β12、(αt⋯α3)2β2、…、(αt)2βt−12、βt2 的正态噪声之和,因为多个独立的正态分布之和还是均值为 0,方差为各个分布方差和的正态分布;
====(αt…α1)2+(αt…α2)2β12+(αt…α3)2β22+⋯+αt2βt−12+βt2(绿色是我们加入的一项)(αt…α2)2α12+(αt…α2)2β12+(αt…α3)2β22+⋯+αt2βt−12+βt2(αt…α2)2(α12+β12)+(αt…α3)2β22+⋯+αt2βt−12+βt2(αt…α3)2(α22(α12+β12)+β22)+⋯+αt2βt−12+βt2αt2(αt−12(…(α22(α12+β12)+β22)+…)+βt−12)+βt2可以发现,如果我们加一个约束将会极大简化这个式子,即要求 αt2+βt2=1,这样上面的平方和就可以等于 1 了。
于是:
xt=记为 αˉt(αt⋯α1)x0+记为 βˉt1−(αt⋯α1)2εt,εt∼N(0,I)
这就给计算 xt 带来了极大的便利。另一方面,DDPM 会选取合适的 αt 形式,使得 aˉ≈0。这意味着经过 T 步分解之后,图像所剩的 语义
可以忽略不计了,已经全部转化为随机噪声 εˉt 了。
分解是 xt−1→xt 的过程,这个过程中有很多数据对 (xt−1,xt),这样重建自然就是从这些数据中学习一个 xt→xt−1 的模型。设该模型为 μ(xt),那么很容易想到学习方案就是最小化二者的欧式距离:
∥xt−1−μ(xt)∥2
注意到分解时 xt=αtxt−1+βtεt,于是 xt−1=αt1(xt−βtεt),所以我们不如直接学习预测噪声。则模型可以设计为:
μ(xt)=αt1(xt−βtϵθ(xt,t))
的形式,其中的 θ 为模型的参数,将其带入到损失函数,得到
==∥xt−1−μ(xt)∥2αt1(xt−βtεt)−αt1(xt−βtϵθ(xt,t))2αt2βt2∥εt−ϵθ(xt,t)∥2
前面的系数代表 loss 的权重,可以先不考虑。我们可以给出 xt 的表达式:
xt=αtxt−1+βtεt=αt(αˉt−1x0+βˉt−1εt−1)+βtεt=αˉtx0+αtβˉt−1εt−1+βtεt
得到损失函数的形式为:
εt−ϵθ(αˉtx0+αtβˉt−1εt−1+βtεt,t)2
为什么要回推到
xt−1你可能会产生疑问:为什么我们需要利用 xt−1? 为什么我们不直接用
xt=αˉtx0+βˉtεˉt让我们先看看,如果这样写,损失函数会是什么样子吧:
=∥εt−ϵθ(xt,t)∥2εt−ϵθ(αˉtx0+βˉtεˉt,t)这个时候,问题就发生了,我们知道 εˉt 和 εt 其实不是相互独立的,所以我们不能在已经采样 εt 情况下完全独立地采样 εˉt
降低方差
上面的损失函数已经可以用于模型的训练了,但是在实践过程中会发现收敛过慢等问题。其原因是
上面需要我们采样的随机变量太多了。
- 随机采样样本
- 从正态分布 N(0,I) 中采样两个随机变量 εˉt,εt
- 从 1∼T 采样一个 t
观察上面损失函数 αtβˉt−1εt−1+βtεt 部分。嘿,这不也是一个正态分布吗,我们可以合成为一个随机变量。
αtβˉt−1εt−1+βtεt=βˉtε∣ε∼N(0,I) 显然新随机变量的均值为 0,方差为各个随机变量方差的和
===(αtβˉt−1)2+βt2αt2(1−(αt−1⋯α1)2)+βt21−(αt⋯α1)2βˉt2
同理可以得到 βtεt−1−αtβˉt−1εt=βˉtω∣ω∼N(0,I),并且可以验证 E[εωT]=0。所以这是两个互相独立的正态随机变量。
于是我们令 αtβˉt−1εt−1+βtεt=βˉtε,那我们的想法就是如何将 εt 利用 ω,ε 表达出来:
求解
εt联立方程
{αtβˉt−1εt−1+βtεt=βˉtεβtεt−1−αtβˉt−1εt=βˉtω解得:
εt=βt2+αt2βˉt−12(βtε−αtβˉt−1ω)βˉt=βˉtβtε−αtβˉt−1ω
将结果带入到上面损失函数式子
=Eεt−1,εt∼N(0,I)[εt−ϵθ(αˉtx0+αtβˉt−1εt−1+βtεt,t)2]Eω,ε∼N(0,I)[βˉtβtε−αtβˉt−1ω−ϵθ(αˉtx0+βˉtε,t)2]
为了方便求均值,我们有 ϵθ:=ϵθ(αˉtx0+βˉtε,t),于是整个平方可以有:
∥⋅∥2==(βˉtβtε)2−βˉt22βtεαtβˉt−1ω+βˉt2(αtβˉt−1ω)2−βˉt2βtεϵθ+βˉt2αtβˉt−1ωϵθ+ϵθ2(整理一下)=:A1(βˉtβtε)2−=:A2βˉt22αtβˉt−1βtεω+=:A3(βˉtαtβˉt−1ω)2−=:A4βtˉ2βtεϵθ+=:A5βˉt2αtβˉt−1ωϵθ+=:A6ϵθ2下面我们分别分析几个部分
- A1: 均值为常数.
- A2: 独立的两个随机变量乘积的均值等于均值的乘积,即有:E[A⋅B]=E[A]⋅E[B],所以该项依然为 0.
- A3: 均值为常数.
- A4: 由于 ϵθ 是含 ε 的非线性函数,非独立的,所以拆解不了.
- A5: 由于 ϵθ 是不含 ω 的;所以该项乘积两边依然是独立的,所以该项均值为 0.
- A6: 无法判断
于是我们可以放心大胆的将 ω 项从损失函数中移除。于是我们得到:
βˉt2βt2Eε∼N(0,I)[ε−βtβˉtϵθ(αˉtx0+βˉtε,t)2]+ 常数
忽略掉常数项,以及系数,得到:
ε−βtβˉtϵθ(αˉtx0+βˉtε,t)2
θminE(x0,c)∼(X0,C)[Et∼U(0,T),ε∼N(0,I)[ε−βtβˉtϵθ(αˉtx0+βˉtε,c,t)2]],
在训练完毕后,我们从一个随机噪声 xt∼N(0,I),执行 T 步下式生成
xt−1=αt1(xt−βtϵθ(xt,t))
如果要 Random Sample,那么需要加上噪声项:
xt−1=αt1(xt−βtϵθ(xt,t))+σtz,z∼N(0,I)
为什么这里需要添加噪声项,并且该噪声项相比于降噪项大的多?
相关资料