狗都能看懂的DDPM论文详解
狗都能看懂的DDPM论文详解
DDPM(Denoising Diffusion Probabilistic Models)是扩散模型的一种,在视觉领域是属于生成式的模型。这篇文章将从扩散模型的基本概念出发,逐步深入讲解DDPM的核心原理,包括Diffusion阶段和Reverse阶段的具体原理与公式推导,最后总结DDPM的算法流程。
DDPM/扩散模型是什么
DDPM(Denoising Diffusion Probabilistic Models)是扩散模型的一种,在视觉领域是属于生成式的模型。
扩散模型(Diffusion Model)的概念最早可以追溯到统计物理学中的玻尔兹曼机(Boltzmann Machines)。这是一种基于能量函数的概率模型,其目的是通过初始状态的蒙特卡洛采样,通过不断更新样本状态,使系统状态分布逐渐接近目标分布。而后科学家又在非均衡热力学的研究中发现,从一个初态经过一系列中间状态最终达到稳定状态,这与扩散模型通过一系列迭代过程从初始状态演化到目标分布的思想相契合。
扩散模型中最重要的思想根基是马尔可夫链,它的一个关键性质是平稳性。即如果一个概率随时间变化,那么在马尔可夫链的作用下,它会趋向于某种平稳分布,时间越长,分布越平稳。如图所示,当你向一滴水中滴入一滴颜料时,无论你滴在什么位置,只要时间足够长,最终颜料都会均匀的分布在水溶液中。这也就是扩散模型的前向过程。
既然颜料均匀分布在溶液中这个过程是必然会发生的,也就是说,通过某种方式,我们可以将其恢复成原来的状态。假设我们将颜料想象成一个噪声,它可以是一个任意的正态分布,我们在不断对一张图片逐步添加噪声的过程就可以看作是一个扩散过程。当添加噪声的次数足够多的时候,它已经变成了一个完全为噪声的图像,也就是我们说的稳定的状态,那么反向去噪的这个过程就是可逆的。那么给定一个神经网络,它只要预测出噪声,就能逐步将图像恢复。这个过程也被称为Stable Diffusion。
Stable Diffusion 分为 Diffusion 和 Reverse 两个阶段。其中 Diffusion 阶段通过不断地对真实图片添加噪声,最终得到一张噪声图片。而 Reverse 阶段,模型需要学习预测出一张噪声图片中的噪声部分,然后减掉该噪声部分,即:去噪。随机采样一张完全噪声图片,通过不断地去噪,最终得到一张符合现实世界图片分布的真实图片。以下分别介绍两个阶段的具体原理与公式推导。
Diffusion 阶段
这个阶段就是不断地给真实图片加噪声,经过T步加噪之后,噪声强度不断变大,得到一张完全为噪声的地图像。整个扩散过程可以近似看成一次加噪即变为噪声图,那么其实我们只需要搞清楚其中一步加噪就可以了,也即搞清楚Xt=f(Xt−1)中的f(x)的过程。
f(x)在论文的公式中有明确的定义:
Xt=1−βt∗Xt−1+βt∗ZtZt∼N(0,I)
t是时间序列中一个值,取值范围为[0,T],Zt是对应时间产生的随机噪声,βt是超参数,也是序列中的一个值,在论文的实验部分,其经验值范围是[10−4,0.02]线性变化,而且一般来说,t越大,βt的取值也就越大(一开始,加一点点噪声就能比较明显的看出和原图的区别,越到后面,图像退化的越厉害,轻微的扰动已经看不出明显的变化,所以βt的值需要更大)
训练时,这样逐步加噪声效率太低了。想要提高训练效率。那么既然最终都会扩散成一个稳定的状态,那么是否我们可以实现从X0直接扩散成XT呢?答案是可以的。
首先,这里先做一个简单的变化,αt=1−βt,那么Xt就变为:
Xt=αt∗Xt−1+1−αt∗Zt
既然要从X0求到XT,那我们一步一步求,其中:
Xt−1=αt−1∗Xt−2+1−αt−1∗Zt−1
将Xt−1代入到Xt的公式中,得:
Xt−1=αt(αt−1∗Xt−2+1−αt−1∗Zt−1)+1−αt∗Zt=αtαt−1∗Xt−2+αt−αtαt−1∗Zt−1+1−αt∗Zt
Zt、Zt−1是从均值为0,方差为单位矩阵的正态分布的两次独立采样,所以:
αt−αtαt−1∗Zt−1∼N(0,(αt−αtαt−1)∗I)1−αt∗Zt−1∼N(0,(1−αt)∗I)
二者相加,即为方差相加,得:
N(0,(1−αtαt−1)∗I)
所以Xt−1的公式可以写成:
Xt−1=αtαt−1∗Xt−2+1−αtαt−1∗ZZ∼N(0,I)
那其实足以看出,从X0推导至任意Xt,有:
Xt=αtαt−1…α1∗X0+1−αtαt−1…α1∗ZZ∼N(0,I)
简写一下:
Xt=αt‾∗X0+1−αt‾∗Z(1)Z∼N(αt∗X0,(1−αt‾)∗I)
其中αt‾代表累乘。
当t很大的时候,Xt≈Z(全为噪声),αt‾为0,βt此时应比较大,也符合我们一开始的所给出来的结论。
diffusion阶段的总结:
核心公式(从X0一次扩散到Xt):Xt=αt‾∗X0+1−αt‾∗ZZ∼N(αt∗X0,(1−αt‾)∗I)
其中某一步(从Xt−1扩散到Xt):Xt=αt∗Xt−1+1−αt∗ZtZ∼N(αt∗Xt−1,(1−αt)∗I)
Reverse阶段
我们先来看一下整个reverse阶段是在做什么,首先取出batch size大小的t,然后针对每个image做diffusion,将我们得到的noise图像放到UNet网络预测噪声Z(指代图中Z′),然后用noise信息预测多余的噪声Z即可。所以整个ddpm,需要训练的就是一个预测噪声的网络,使得预测出来的噪声与实际加的噪声越接近越好。对比GAN网络,不难发现GAN是需要训练2个模型,训练过程极其不稳定,有时候生成器训好了,判别器却没训好,以至于loss都不能真实的反映网络的性能。而ddpm只需要训练一个网络,相比之下稳定很多。
在加噪声的过程中,我们为了减少计算消耗,算出了一次扩散的公式,理论上我们也可以得到一次减噪的公式:
X0=(Xt−1−αt‾Z~)αt‾(2)
论文中的结论可以知道,这么做的效果比较差,图片是很模糊的,不符合逆扩散的过程,最好还是一步一步推。先根据Xt预测出Z~,求出Xt−1,然后逐步逐步得到X0,这个过程如下图所示:
现在我们知道Z=UNet(Xt,t)计算得出,整个reverse过程中,就只剩下xt−1=f(Xt,Z)的f(x)这个过程还没搞清楚了。
我们要推理Xt→Xt−1的过程,相当于已知Xt的概率,去求Xt−1的条件概率,即计算q(Xt−1∣Xt),根据贝叶斯公式,有:
q(Xt−1∣Xt)=q(Xt,Xt−1)q(Xt)=q(Xt∣Xt−1)q(Xt−1)q(Xt)
那么同样用条件概率的方式去等价(具体公式和其服从的正态分布见上文):
- Xt−1→Xt可以用q(Xt∣Xt−1)表示
- X0→Xt可以用q(Xt)、q(Xt−1)表示
那么如果将所有的概率都用正态分布表示:
- q(Xt∣Xt−1)∼N(αt∗Xt−1,(1−αt)∗I)
- q(Xt)∼N(αt‾∗X0,(1−αt‾)∗I)
- q(Xt−1)∼N(αt−1‾∗X0,(1−αt−1‾)∗I)
而在已知高斯分布的均值和方差时,有正比关系:N(μ,σ2)∝exp(12∗(x−μ)2σ2), 将上面几个高斯分布的均值和方差分别代入(分子上相加,分母上相减),得:
q(Xt∣Xt−1)q(Xt−1)q(Xt)∝exp{−12(∗(xt−αt∗Xt−1)21−αt+(xt−1−αt−1‾∗X0)21−αt−1−(xt−αt‾∗X0)21−αt‾)}
但别忘了,我们最初的目标是求分布q(Xt−1∣Xt),也即求Xt−1,可以观察到,目前我们推导的结果是一个Xt−1的二项式,将其配方,找到我们关心的q(Xt−1∣Xt)的均值和方差。我们对上式进一步简化:
q(Xt−1∣Xt)∝exp{−12(αtβt+11−αt−1‾∗Xt−12)−2(αt∗Xtβt+αt−1‾∗X01−αt−1‾)∗Xt−1+?}
最后边的常量是什么不重要,我们只关心均值和方差,所以利用变量A、B替代,简化求解过程:
exp∝−12(A∗Xt−12)−2B∗Xt−1+Cexp∝{−12A(Xt−1+B2A)2+C}
由此可得均值和方差表示为:
μ=−B2Aσ2=1A
而A与B是替代的变量,为:
A=αtβt+11−αt−1‾B=αt∗Xtβt+αt−1‾∗X01−αt−1‾
代入计算,方差为:
σ2=1A=1/(αtβt+11−αt−1‾)=1/(αt−αt∗αt−1‾+βtβt∗(1−αt−1‾))=1−αt−1‾1−αt‾∗βt
均值为:
μ=−B2A=(αt∗Xtβt+αt−1‾∗X01−αt−1‾)∗αt−1‾1−αt‾∗βt=αt∗1−αt−1‾1−αt‾∗Xt+αt−1‾∗βt1−αt‾∗X0
所以为什么说逆扩散的时候,一步一步推是更准的,因为这个地方的X0是估计出来的,里面含有Z=UNet(Xt,t),由于这个值每一次都是当前步估计的结果,而它本身由Xt和Z计算得来,这两个值,t越小,占比也就越小,噪声越小,估计也就越准。
从结果可以看出,均值和方差都是由已知的α、β计算出来的,这些都是我们预设好的超参数,而Xt、X0又是之前公式(1)(2)中以求的了,继续将X0待入,可得:
μ=αt∗1−αt−1‾1−αt‾∗Xt+αt−1‾∗βt1−αt‾∗(Xt−1−αt‾∗Z)αt‾=Xtαt∗(αt−αt‾+βt1−αt‾)+Zαt∗βt1−αt‾=1αt∗(Xt−βt1−αt‾∗Z~)
由于我们刚刚说过:
q(Xt−1∣Xt)=Xt−1∼N(μ,σ2∗I)
所以最终得到的结论就是:
q(Xt−1∣Xt)∼N(1αt(Xt−βt1−αt‾∗Z~,1−αt−1‾1−αt‾∗βt)
利用重参数化技巧(高斯分布里面写的是方差,乘的是标准差,要加上一个根号),得:
Xt−1=1αt(Xt−βt1−αt‾∗Z~)+1−αt−1‾1−αt‾∗βt)∗Z
其中
Z~=UNet(Xt,t)Z∼N(0,I)
这里的Z采样也是希望在重建的过程中,能添加一些不确定性,不至于每一次重建的结果都是由UNet决定。从公式里面也可以看出,Xt−1也是从Xt减去Z~,将预测噪声移除。
这里补充说明一下,重参数的过程,假设从某个正态分布N(μ,σ2∗I)采样一个X的话,它可以等价于,从一个标准正态分布N(0,I)去采样一个Z,然后利用Z去生成X:
X=μ+σ∗Z∼N(μ,σ2∗I)
总结
最后我们对照一下DDPM中的给出的算法流程。
训练过程中,我们对每个x都会采样出一个t,然后根据t,生成对应的噪音ϵ,我们的UNet网络需要预测的就是这个噪声,它的参数被记作θ,这里额外说明一下,也并不是一定要用UNet,只是这个网络结构资源消耗和适用性更好。当训练完成之后,我们就有了一个去噪网络。
采样过程中,我们是没有任何真实图像的,所以我们需要从一个标准正态分布中采样一个XT,这是我们采样的起点,接下来,我们会对它做T步的reverse,一直推到X0,这里算法还有个小细节,只有t>1的时候,z才需要采样,否则它就是0,当t=0时,我们想求的就是真实的X0,这时候就不需要加扰动了,它必须是个确定的结果。相当于均值给定的是一个确定的生成方向,方差和噪声给定的是一个不确定的方向。另外呢,从训练经验来看,这个扰动值也不需要和推理结果完全一样,论文只是提供了这个扰动强度的上界σ,比他小甚至为 0,也是可以的。