【深度学习】VAE(Variational Auto-Encoder)原理
【深度学习】VAE(Variational Auto-Encoder)原理
变分自编码器(VAE)是深度学习领域中一种重要的生成模型,它在图像生成、数据增强等领域有着广泛的应用。本文将从AE与VAE的基本概念出发,深入解析VAE的数学原理,并介绍reparameterization技巧,最后讨论VAE的不足之处。
一、AE与VAE
AE(Auto-Encoder) 是一个应用很广泛的机器学习方法。主要内容即是:将输入(Input)经过编码器(encoder)压缩为一个编码(code),再通过解码器(decoder)将编码(code)解码为输出(Output)。学习的目标即:要使得输出(Output)与输入(Input)越接近越好。
以输入为图像为例,结构图如下:
AE中间阶段生成的编码向量,并不是随机、没有意义的。编码中携带着与输入有关的信息,编码中的某些维度代表着输入数据的某些特征。例如生成人脸图像时,编码可以表示人脸表情、头发样子、是否有胡子等等。
VAE(变分自动编码器) 作为AE的变体,它主要的变动是对编码(code)的生成上。编码(code)不再像AE中是唯一映射的,而是具有某种分布,使得编码(code)在某范围内波动时都可产生对应输出。借助下面这个例子进行理解:
如上图AE示意图,左侧是对满月图像编解码,右侧是对弦月图像编解码,而像中间的编码对解码器来说并不知道要生成何种图像。在VAE示意图中,左右两侧对图像编解码过程中,编码有不同程度的扰动(即图中noise),解码器利用扰动范围内的编码同样可以生成相应的图像,而对交界处的编码,编码器既想生成满月图像,又想生成弦月图像,为此做出折中,生成位于两者之间的图像。
这就是VAE一个较为直观的想法。
二、VAE原理
VAE是一个深度生成模型,其最终目的是生成出概率分布P(x),x即输入数据。
在VAE中,我们通过高斯混合模型(Gaussian Mixture Model)来生成P(x),也就是说P(x)是由一系列高斯分布叠加而成的,每一个高斯分布都有它自己的参数μ和σ。
那我们借助一个变量z∼N(0,I)(注意z是一个向量,生成自一个高斯分布),找一个映射关系,将向量z映射成这一系列高斯分布的参数向量μ(z)和σ(z)。有了这一系列高斯分布的参数我们就可以得到叠加后的P(x)的形式,即x∣z∼N(μ(z),σ(z))。(这里的“形式”仅是对某一个向量z所得到的)。
那么要找的这个映射关系P(x∣z)怎么获得呢?就拿神经网络来做呗,只要神经元足够想要啥样的函数得不到呢。如下图形式:
输入向量z,得到参数向量μ(z)和σ(z)。这个映射关系是要在训练过程中更新NN权重得到的。这部分作用相当于最终的解码器(decoder)。
对于某一个向量z我们知道了如何找到P(x)。那么对连续变量z依据全概率公式有:
P(x)=∫zP(z)P(x∣z)dz
但是很难直接计算积分部分,因为我们很难穷举出所有的向量z用于计算积分。又因为P(x)难以计算,那么真实的后验概率P(z∣x)=P(z)P(x∣z)/P(x)同样是不容易计算的,这也就是为什么下文要引入q(z∣x)来近似真实后验概率P(z∣x)。
因此我们用极大似然估计来估计P(x),有似然函数L:
L=∑xlogP(x)
这里我们额外引入一个分布q(z∣x),z∣x∼N(μ′(x),σ′(x))。这个分布表示形式如下:
这个分布同样是用一个神经网络来完成,向量z根据NN输出的参数向量μ′(x)和σ′(x)运算得到,注意这三个向量具有相同的维度。这部分作用相当于最终的编码器(encoder)。
之后就开始推导了。
logP(x)=∫zq(z∣x)logP(x)dz
∫zq(z∣x)d
z
1
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
P
(
z
∣
x
)
d
z
∫
z
q
(
z
∣
x
)
log
(
P
(
z
,
x
)
q
(
z
∣
x
)
⋅
q
(
z
∣
x
)
P
(
z
∣
x
)
)
d
z
∫
z
q
(
z
∣
x
)
log
q
(
z
∣
x
)
P
(
z
∣
x
)
d
z
+
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
D
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
∣
x
)
)
+
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
⪖
∫
z
q
(
z
∣
x
)
log
P
(
z
,
x
)
q
(
z
∣
x
)
d
z
∵
D
K
L
(
q
∣
∣
P
)
⪖
0
我们将∫zq(z∣x)logP(z,x)q(z∣x)dz称为logP(x)的(variational) lower bound,简称为Lb。最大化Lb就等价于最大化似然函数L。那么接下来具体看Lb,
Lb=∫zq(z∣x)logP(z,x)q(z∣x)dz
=∫zq(z∣x)log(P(z)q(z∣x)⋅P(x∣z))dz
=∫zq(z∣x)logP(z)q(z∣x)dz+∫zq(z∣x)logP(x∣z)dz
=−DKL(q(z∣x)∣∣P(z))+∫zq(z∣x)logP(x∣z)dz
=−DKL(q(z∣x)∣∣P(z))+Eq(z∣x)[logP(x∣z)]
最大化Lb包括下面两部分:
- minimizingD
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
,使后验分布近似值q(z∣x)接近先验分布P(z)。也就是说通过q(z∣x)生成的编码z不能太离谱,要与某个分布相当才行,这里是对中间编码生成起了限制作用。
当q(z∣x)和P(z)都是高斯分布时,推导式有([2]中Appendix B):
D
K
L
(
q
(
z
∣
x
)
∣
∣
P
(
z
)
)
−
1
2
∑
j
J
(
1
+
log
(
σ
j
)
2
−
(
μ
j
)
2
−
(
σ
j
)
2
)
其中J表示向量z的总维度数,σj和μj表示q(z∣x)输出的参数向量σ和μ的第j个元素。(这里的σ和μ等于前文中μ′(x)和σ′(x))
- maximizingEq(z∣x)[logP(x∣z)],即在给定编码器输出q(z∣x)下解码器输出P(x∣z)越大越好。这部分也就相当于最小化Reconstruction Error(重建损失)。
补充点:重建损失函数选择交叉熵损失还是平方差损失,是跟P(x∣z)形式有关的,再取对数似然。知乎回答[6]和专栏[7]中有进行讲解说明。引用[6]中用户Taffy lll的回答:
重建损失的数学形式是对数似然logp(x∣z),它的具体表达式和p(x∣z)相关。一般来说,p(x∣z)的选取和x的取值空间是密切相关的: 如果x是二值图像,这个概率一般用伯努利分布,而伯努利分布的对数似然就是binary cross entropy,可以调各大DL库里的BCE函数;如果x是彩色/灰度图像,这个概率取高斯分布,那么高斯分布的对数似然就是平方差。
由此我们可以得出VAE的原理图:
通常忽略掉decoder输出的σ(x)一项,仅要求μ(x)与x越接近越好。
对某一输入数据x来说,VAE的损失函数即:
minLossVAE=DKL(q(z∣x)∣∣P(z))−Eq(z∣x)[logP(x∣z)]
附:
极大似然估计P(x)的时候还有一种写法,即通过P(x)=∫zP(x,z)dz来推导。如图[3]:
里边有提到术语ELBO,EvidenceLowerBOund(证据下界),有兴趣的可以自行查阅了解(也就是上文提到的变分下界,不过ELBO叫法更普遍)。
三、reparameterization trick
由上文中VAE原理图可以看出,z∼q(z∣x),即编码z是由分布q(z∣x)采样产生,而采样操作是不可微分的,因此反向传播做不了。[2]中提到了reparameterization trick来解决,借助[4]中的示意图理解下:
将上图左图原来的采样操作通过reparameterization trick变换为右图的形式。
我们引入一个外部向量ϵ∼N(0,I),通过z=μ+σ⊙ϵ计算编码z(⊙表示element-wise乘法,ϵ的每一维都服从标准高斯分布即ϵi∼N(0,1)),由此loss的梯度可以通过μ和σ分支传递到encoder model处(ϵ并不需要梯度信息来更新)。
这里利用了这样一个事实[5]:
考虑单变量高斯分布,假设z∼p(z∣x)=N(μ,σ2),从中采样一个z,就相当于先从N(0,1)中采样一个ϵ,再令z=μ+σ⊙ϵ。
最终的VAE实际形式如下图所示:
四、不足
VAE在产生新数据的时候是基于已有数据来做的,或者说是对已有数据进行某种组合而得到新数据的,它并不能生成或创造出新数据。另一方面是VAE产生的图像比较模糊。
而大名鼎鼎的GAN利用对抗学习的方式,既能生成新数据,也能产生较清晰的图像。后续的更是出现了很多种变形。
五、参考文献
[1]Unsupervised Learning: Deep Generative Model (2017/04/27)
[2] VAE原著Auto-encoding variational bayes
[3]VAE的三种不同推导方法
[4]https://www.jeremyjordan.me/variational-autoencoders/
[5]变分自编码器VAE:原来是这么一回事
[6]变分自编码器的重建损失为什么有人用交叉熵损失?有人用平方差?
[7]再谈变分自编码器VAE:从贝叶斯观点出发
[8]posterior collapse 后验消失问题是什么
本文原文来自CSDN,作者NooahH,原文链接:https://blog.csdn.net/NooahH/article/details/104676242