扩散模型(Diffusion Model)原理讲解 数学公式推导 简洁易懂版
扩散模型(Diffusion Model)原理讲解 数学公式推导 简洁易懂版
什么是扩散模型?
每接触一个新的模型都会问的问题,GPT的解释是这样的:
Diffusion Model 是近年来在生成式建模领域取得显著进展的一类概率模型,主要用于生成高质量的样本,比如图像、音频、视频等数据类型。它基于对数据分布的逐步逼近与还原,结合了理论上的优雅性和实际应用的高性能。
其核心思想是通过一系列逐步增加噪声的过程将数据分布转换成一个简单的先验分布(如高斯分布),再训练一个模型来逆转这个过程,即从噪声中逐渐恢复出原始的数据分布。
扩散模型的训练目标通常是优化一个损失函数,这个损失函数衡量的是模型预测的噪声与实际添加的噪声之间的差异。训练过程中,模型学习到的是如何从噪声数据中恢复出原始数据的方法,因此能够用于生成新的数据样本。
可以理解为模拟的是一个正向扩散过程和反向生成过程。在正向过程中,是将噪声逐步加入到数据中,最终转化为近似高斯分布的情况;在反向过程中,是从纯噪声逐步还原数据,直到生成近似真实样本的结果。
公式推导
前向过程
αt=1−βt
这里的β会逐渐变大,从0.0001到0.002,对应的α就会逐渐变小,在代码中是直接在这个范围内等间隔采样,随着迭代次数的增加,β增大。
xt=αtxt−1+1−αtzt(1)
这个式子描述前向过程中由xt−1得xt。可以看到在开始时候,只加一点噪声,后来越加越多,直到近似成为全噪声的图像。
现在需要解决的问题:对于整个序列,一个一个计算太费事,对任意时刻的Xt能不能直接由X0计算得来?
xt−1=αt−1xt−2+1−αt−1zt(2)
将这个(2)式带入(1)式,有:
xt=αt(αt−1xt−2+1−αt−1zt2)+1−αtzt1(3)
目前已知的是每次加入的噪声z1,z2等都服从高斯分布N(0,I)
将(3)式展开:
xt=αtαt−1xt−2+(αt(1−αt−1)zt2+1−αtzt1)(4)
xt=αtαt−1xt−2+1−αtαt−1z(5)
已知(4)中的z1,z2分别服从N(0,1−αt),N(0,at(1−αt−1))。
有性质N(0,σ12I)+N(0,σ22I)∼N(0,(σ12+σ22)I),所以(5)式中的z的方差是1−αt+at(1−αt−1)=1−αtαt−1
观察(5)式,可以发现xt和xt−2的关系可以迭代推广,一直迭代到x0,得到:
xt=αtx0+1−αtzt(其中αt指的是累乘,就是αt∗αt−1∗αt−2∗∗α1)(6)
上面的式子说明对于扩散模型的前向过程,可根据x0直接得到任意时刻的分布。到此,前向过程可以进行。
反向过程
根据上面的逆向图例,要根据XT求XT−1,就需要知道q(xT−1∣xT)
有贝叶斯公式可知,q(xT−1∣xT)=q(xT∣xT−1)q(xT−1)q(xT)
结合正向过程中的x0,有q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt−1∣x0)q(xt∣x0)
上式左侧的三个部分都可求,对应的q(xt−1∣xt,x0)也可求。
q(xt−1∣x0)a‾t−1x0+1−a‾t−1z∼N(a‾t−1x0,1−a‾t−1)
q(xt∣x0)a‾tx0+1−a‾tz∼N(a‾tx0,1−a‾t)
q(xt∣xt−1,x0)atxt−1+1−atz∼N(atxt−1,1−at)
上面第一个式子可以理解为直接根据x0求xt−1,第二个式子是直接根据x0求xt,直接套用上面的公式(6)即可,而对第三个单纯就是式子(1)。
已知对于高斯分布N(μ,σ),有概率密度函数p(x)=12πσe−(x−μ)22σ2
即N(μ,σ)∝e−(x−μ)22σ2,可以根据上面部分的高斯分布得到我们要求解的q(xT−1∣xT)正比的形式。
q(xt−1∣xt,x0)∝exp(−12((xt−αtxt−1)2βt+(xt−1−αˉt−1x0)21−αˉt−1−(xt−αˉtx0)21−αˉt))
根据对应关系,可以看到上面图片中的红色部分括号中对应分布的方差。又因为在最原始的模型中,αt和βt都是固定已知的,所以方差已知。上图中蓝色部分对应可以求解均值,求出的均值:
μ~t(xt,x0)=αt(1−αˉt−1)1−αˉtxt+αˉt−1βt1−αˉtx0
但是目前存在的问题:X0就是要反向过程要求解的状态。
这个时候正向过程的(6)式就可以拿来替换了,由(6)式得到x0=1αˉt(xt−1−αˉtzt)
最终结果:μ~t=1at(xt−βt1−a‾tz)
目前已知方差和均值,可以将反向的过程一步一步进行下去了。
但是目前又出现问题:上式中的zt用数学方法始终没办法求,所以只能借助于模型训练,通过模型预测在某时刻t的噪声。(ps:一顿操作猛如虎,最后还是需要神经网络出手 -_-)
模型训练
上面提到了,其实就是对每一步的噪声进行拟合,模型的训练需要标签,在扩散模型中,正向过程加噪的过程中,自己加入的噪声肯定是已知的。那么在反向的过程中,关注模型预测出来的噪声和原来加入的噪声之间的差异,尝试最小化两者之间差异就可以进行训练。
算法解读
训练阶段
#2对于某一个特定的分布q(x0),在该分布中进行采样(大致可以理解为,比如全是猫的数据集,全是狗的数据集,在这个特定的数据集里面进行采样得到x0)
#3对应的t是在1到T这个范围内随机选的。在同一个batch中,每个图片对应的t也不一致
#4前提:噪声需要服从标准正态分布
#5模型的训练,其中是指要训练的模型,模型的输入就是图中框出来的部分,也就是XT,模型的输入同时还包含 t,就是把时刻也输入到了模型中,在实际操作中,会根据 t 生成一个向量(正弦位置嵌入),作为轮数的编码。用模型预测值不断拟合真实值,通过Unet这个框架学习到噪音的信息。
采样阶段
#1xT 是随机采样的,看作高斯分布
#2#3#4#5做循环,从 xT 一直循环到 x_1,逐步从全噪音图还原成想要的图片。使用的是推导的公式,配合已经训练的模型,可以实现想要的效果。
#6得到最终的x0