深入探索VQ-VAE:原理、架构与应用实践全解析
深入探索VQ-VAE:原理、架构与应用实践全解析
VQ-VAE(Vector Quantized Variational Autoencoder)是一种结合了变分自编码器(VAE)和向量量化(Vector Quantization)技术的生成模型。它通过引入离散编码的概念,解决了传统VAE在图像生成任务中的一些局限性。本文将深入探讨VQ-VAE的原理、架构及其与传统VAE和AE的区别。
1. VQ离散编码思路
自编码器(AE)无法进行图像生成,因为它过于拟合数据。为了解决这个问题,VAE引入了噪声,使得编码结果在一个正态分布上采样,从而实现图像生成。然而,VAE通过平衡重构损失和KL散度来工作,这种平衡很难把握。
VQ-VAE通过引入离散编码的概念来克服这一问题。离散化编码更有意义,因为使用PixelCNN网络让这个离散化编码有了实际意义。作者假设先验的离散分布p是均匀分布,但实际生成任务中并没有在均匀分布上采样,而是引入了PixelCNN使得整个离散分布有了意义。这也是为什么说VQ-VAE更像一个AE而不是VAE的主要原因。
1.1 整体过程
论文中很多内容没有解释清楚,包括损失函数的推导过程。我们通过以下图示来解释VQ-VAE的工作流程:
Encoder部分:通过卷积神经网络(CNN)将图像编码到一个m x m x D的维度上,保留了图像的空间信息。这个结果是$z_e(x)$,与AE的Encoder过程类似,但这里保留了空间信息的m*m个D维向量。
$z_e(x) \rightarrow z_q(x)$通过embedding space映射:
- embedding space维护了一组K个可训练的参数向量$E = (e_1, e_2, ..., e_k)$;
- 通过最近邻算法从E中找到最相近的下标,假设压缩后的维度是22D,并假设找到最近邻的下标是$\begin{bmatrix} 0 & 1 \ 2 & 2 \ \end{bmatrix}$,则$z_q(x)$为$\begin{bmatrix} e_1 & e_2 \ e_3 & e_3 \ \end{bmatrix}$;
- $z$就是由上一步最近邻得到的下标矩阵$\begin{bmatrix} 0 & 1 \ 2 & 2 \ \end{bmatrix}$,也是离散化的结果。这里可以看出$z_q(x)$并不是在$q(z|x)$这个分布上采样得到,而仅仅是$z_e(x)$的一个恒等映射。这恰恰说明了VQ-VAE为什么不能说是一个VAE,而更像一个AE。
- codebook是什么?还记得我们维护一个embedding Space和生成的离散化z矩阵吗?我们可以通过embedding Space,把z中离散化的下标整数翻译成一个个embedding Space空间中的向量,是不是很像一个密码本,这就是为什么叫embedding Space为codebook的原由。
Decoder部分:通过CNN将$z_q(x)$解码回图像,与AE的Decoder过程类似。
离散化的编码z似乎并没参与,好像有没有无所谓,但是为什么要做这一步呢,事实上离散化编码并不会参与到训练的计算图中,也就谈不上参与梯度的计算。前面我们都说了$z_e(x) \rightarrow z_q(x)$本身就是一个恒等映射的关系,你可以认为离散编码z只不过在求这个映射的额外产物而已。但是为什么要有呢,你可别忘了,最终我们在生成图像的时候,只是用到decoder,要有一个采样过程,而这个采样就是在z上随机采样。那你会问在z上采样不就是在embedding上采样吗,完全可以不用z啊?如果你能问出这样的问题,VQ-VAE你是真的懂了。理论上是这样,完全没错。理论上我们可以随机在z上也就是在embedding上通过均匀分布(作者假设先验分布是均匀分布)采样得到$z_q(x)$,但是效果很差,原因是什么,原因就是VQ-VAE是一个AE啊,本身就不具备生成的能力,即使离散化的z也是有一定的意义的。那怎么办的,还需要请出pixelcnn模型,既然z是有意义的,巧了pixelcnn专门干这事的,通过训练一个pixelcnn自然就得到一个有意义的z,自然就得到$z_q(x)$,就可以decoder生成图像了。
1.2 为什么要和VAE扯上关系
直接把VAE中的loss拿过来:
$$L(\theta,\phi;x^{(i)})=-KL(q_{\phi}(z|x^{(i)})||p_{\theta}(z))+E_{q_{\phi}(z|x^{(i)})}[ log(p_{\theta}(x^{(i)}|z)]$$
作者做了两个假设:第一个是先验分布$p_{\theta}(z)=1/K$是一个均匀分布,第二个是编码网络得到结果$q_{\phi}(z|x^{(i)})$为one-hot。关于one-hot是什么意思呢,做过分类任务都知道,one-hot编码有一个特点,就是等于这个编码结果概率为1,否则为0。
那么我们再来看KL散度这一项:
$$KL(q_{\phi}(z|x^{(i)})||p_{\theta}(z))=\sum_zq_{\phi}(logq_{\phi}-logp_{\theta})=0+0+0+...+1*(log(1)-log(1/K))=logK$$
从上式我们发现KL散度竟然是常数,loss那一项就剩下后面:
$$L(\theta,\phi;x^{(i)})=E_{q_{\phi}(z|x^{(i)})}[ log(p_{\theta}(x^{(i)}|z)]= log(p_{\theta}(x^{(i)}|z)$$
这大概就是作者为什么要这样强行跟VAE扯上关系本质原因。这也解释了第二个问题,看过VAE的都知道,上面的结果是等价于重建的损失的。但是只有一个重建loss似乎不够,别忘了我们还训练了一个codebook。接下来我们继续看完整的损失。
2. 完整的损失函数
我们知道,VAE中也推导也用了最大释然对数等价于均方差构建的重建损失,原因在于假设了高斯分布的情况下,因此可以用重建损失来代替,这也是为什么跟VAE扯上关系的原因,如果认为他是一个AE,那么直接用AE的重建损失就完事了:
$$L=||x-decoder(z_q(x))||_2^2$$
猛一看这个损失没有什么问题,细看,你会发现,它没办法优化encoder和codebook,原因是因为$z_e(x) \rightarrow z_q(x)$是通过codebook映射过去的,没法反向传播。于是作者提出了一个直通的概念,反向传播的时候直接让$z_e(x)$的梯度等于$z_q(x)$的梯度,也是途中红色箭头所示,具体怎么做呢,如果是pytorch,可以用detach来从计算图中分离出来,具体如下:将$(z_q-z_e).detach$即可
$$L=||x-decoder(z_q(x))||_2^2=||x-decoder(z_e+sg(z_q-z_e))||_2^2 \=||x-decoder(z_e+(z_q-z_e).detach)||_2^2$$
这样的话,重建损失有了,encoder和decoder也可以优化了,但是codebook那一项没法优化。那该怎么优化codebook呢,我们知道$z_q$来源于$z_e$在codebook中最近邻球得的,因此好的codebook必然是$z_q$与$z_e$越近越好,自然而然就是下面损失:
$$L_e=||z_q-z_e||_2^2$$
这样的话$z_q$可以更新codebook的参数,$z_e$可以更新encoder的参数,但是作者觉得编码器和codebook学习速度应该不一样快。于是,他们再次使用了停止梯度的技巧,把上面那个误差函数拆成了两部分。用$\beta$来平衡,常取0.1~2.0。作者取值为0.25。于是整体损失变为以下:
$$L=||x-decoder(z_e+sg(z_q-z_e))||_2^2 +||z_q-sg(z_e)||_2^2+\beta ||sg(z_q)-z_e||_2^2$$
至此,我们整个VQ-VAE就讲完了,看完之后是不是很舒服。包括我们如何去训练也一目了然。