理解生成对抗网络(GAN)
理解生成对抗网络(GAN)
一、概念
生成对抗网络(GAN)是深度学习中一类重要的模型,由Ian J. Goodfellow等人于2014年提出。它是一种无监督学习模型,通过生成器(Generator)和判别器(Discriminator)的对抗训练过程,使生成器能够生成与真实数据难以区分的样本。
GAN的主要目标是让生成器学习到真实数据的分布,从而生成高质量的合成数据。在训练过程中,生成器试图生成能够欺骗判别器的样本,而判别器则努力区分真实数据和生成器产生的假数据。最终,当判别器无法准确区分两者时,说明生成器已经很好地学习到了数据分布。
二、GAN基本原理
1. 构成
GAN由两个核心组件构成:生成器(Generator,简写作G)和判别器(Discriminator,简写作D)。
生成器:通过机器生成数据,目的是尽可能“骗过”判别器。生成器的输入是一组随机噪声X(通常是一个随机向量),通过学习一个映射F,将随机噪声X映射为一个结果Y,这个结果就是我们想要生成的图片。生成器的目标是学习到训练数据的分布,使得生成的图片能够以假乱真。
判别器:判断数据是真实数据还是生成器生成的数据。判别器的输入可以是生成器的输出Y或真实的训练数据X,其输出是一个概率值,表示输入数据是真实数据的可能性。判别器的目标是尽可能准确地区分真实数据和生成数据。
2. 训练过程
GAN的训练过程可以分为两个阶段:
第一阶段:固定判别器D,训练生成器G。使用一个性能不错的判别器,G不断生成“假数据”,然后给这个D去判断。开始时候,G还很弱,所以很容易被判别出来。但随着训练不断进行,G技能不断提升,最终骗过了D。这个时候,D基本属于“瞎猜”的状态,判断是否为假数据的概率为50%。
第二阶段:固定生成器G,训练判别器D。当通过了第一阶段,继续训练G就没有意义了。这时候我们固定G,然后开始训练D。通过不断训练,D提高了自己的鉴别能力,最终他可以准确判断出假数据。
通过不断的循环,生成器G和判别器D的能力都越来越强。最终我们得到了一个效果非常好的生成器G,就可以用它来生成数据。
整个训练的过程为:
- 从高斯分布中采样一批次长度为n的噪声向量。
- 利用(1)中噪声向量,使用generator生成假图像。
- 从真实数据采一批次真实图像,与(2)中的假图像混合,做好标签,训练discriminator。
- 再从高斯分布中采样长度为n的一批次噪声向量,标签为“True”,训练GAN,此时GAN中的discriminator参数不能更新,只训练generator。
- 按指定轮数重复上述步骤。
3. GAN的损失函数
- 判别器D的优化目标:
解释:判别器越完美,越能区分数据来源,第一个期望越接近0,第二个期望越接近0。完美的判别器V(D)优化为0。
- 生成器G的优化目标:
解释:因为总损失函数中,第一个期望不含G,所以只用考虑第二个期望。生成器越完美,越能骗过判别器,第二个期望越接近负无穷。完美的生成器V(G)优化为负无穷。
判别器的损失:
- 判别器给真实图片打的分与其期望分数(1)的差距D_L1
- 判别器给生成图片打的分与其期望分数(0)的差距D_L2
- 则生成器的总损失为 D_L1 + D_L2
生成器的损失:
- 生成图片与真实图片的差距
- 实际上,将该差距转化为生成器期望判别器给自己生成图片打多少分与实际判别器打多少分的差距
4. GAN的优缺点
优点
- 能更好建模数据分布(图像更锐利、清晰)
- 理论上,GANs 能训练任何一种生成器网络。其他的框架需要生成器网络有一些特定的函数形式,比如输出层是高斯的
- 无需利用马尔科夫链反复采样,无需在学习过程中进行推断,没有复杂的变分下界,避开近似计算棘手的概率的难题
缺点
- 模型难以收敛,不稳定。生成器和判别器之间需要很好的同步,但是在实际训练中很容易D收敛,G发散。D/G 的训练需要精心的设计
- 模式缺失(Mode Collapse)问题。GANs的学习过程可能出现模式缺失,生成器开始退化,总是生成同样的样本点,无法继续学习