Diffusion2GAN:将扩散模型的知识提炼为条件GAN
Diffusion2GAN:将扩散模型的知识提炼为条件GAN
Diffusion2GAN是一种创新的框架,它将预先训练好的多步扩散模型提炼成一个单步条件生成对抗网络(GAN)。这种方法通过将生成建模任务分解为识别对应关系和学习映射两个部分,实现了性能和计算效率的双重提升。
介绍
扩散模型在LAION等高难度数据集上展现了卓越的图像生成质量,但高质量的结果往往需要数十到数百个采样步骤,导致延迟较高,难以实现实时交互。如果能提供一步式模型,它们将改善文本到图像生成的用户体验,并可用于三维和视频应用。
一个简单的解决方案是从头开始训练一个一步模型:虽然在简单领域中,GAN是一个很好的一步模型,但在大型和多样化数据集上生成文本到图像时,仍然存在挑战。这一挑战是由于需要在无监督的情况下执行两项任务:寻找噪声与自然图像之间的对应关系,以及优化噪声到图像的映射。
本评论文章的想法是逐一解决这些任务。首先,在预先训练的扩散模型中找到噪声与图像之间的对应关系,然后让条件GAN在配对图像转换框架中完成噪声与图像之间的映射。这种方法结合了扩散模型的高质量和条件GAN的快速映射的优点。
在实验中,利用提出的Diffusion2GAN框架,将稳定扩散1.5提炼为单步条件GAN模型。可以更好地学习。此外,它的表现优于UFOGen和DMD,并在COCO2014基准测试中取得了优异成绩。特别是在SDXL蒸馏法中,它的表现优于SDXL-Turbo和SDXL-Lightning。
算法框架
总体概述
图 1:Diffusion2GAN 概述。
图1是整个拟议方法的概览。首先,收集扩散模型的输出潜变量以及输入噪声和提示。然后使用E-LatentLPIPS loss和GAN loss训练生成器,将噪声和提示映射到目标潜变量。生成器的输出可以解码为RGB像素,这是一项计算成本很高的操作,在训练过程中不会执行;E-LatentLPIPS损失和GAN损失将在下一节详细讨论。
E-LatentLPIPS损失
传统知识蒸馏的损耗如下式所示。该损耗可按原样使用(图2-b),但由于它是为像素空间设计的,因此需要在解码器中将其从潜空间改为像素空间。
由于这种操作的计算成本很高,因此需要一种方法来直接计算潜空间中的感知距离,而无需对像素进行解码。
为此,按照Zhang等人的方法,在稳定扩散潜空间的ImageNet上训练了VGG网络。该潜在空间已经进行了8倍的下采样,因此架构略有改动。然后使用BAPPS数据集对中间特征进行线性校准,以获得在潜空间中工作的函数。
dLatentLPIPS(x0,x1)=ℓ(F(x0),F(x1))
图2:在重建图像上验证所提损失函数的有效性。
不过,可以看到,直接应用LatentLPIPS作为蒸馏的新损失函数会产生波浪状的斑块假象,如图2-c所示。
受E-LPIPS的启发,随机可微分扩展、一般几何变换和剪切被应用于生成潜变量和目标潜变量。在每个纪元,生成潜变量和目标潜变量都会被随机扩张。当应用于单图像优化时,集合策略几乎能完美地重建目标图像(见图2-d)。新的损失函数简称为Ensembled-LatentLPIPS或E-LatentLPIPS,其中T是随机取样的扩展。
条件扩散判别器
在上一节中,我们已经证明,扩散蒸馏可以看作是一种成对噪声到潜变量的转换任务。受条件GANs对图像到图像配对转换效果的启发,我们使用了条件判别器。该判别器的条件不仅包括文本描述c,还包括提供给生成器的高斯噪声z。新的判别器结合了上述条件,同时利用了预先训练好的扩散权重。具体来说,生成器G和判别器D的最小-最大目标函数如下。判别器概览见图3。
图3.多尺度条件判别器的设计
测试试验
与蒸馏扩散模型进行比较
将Diffusion2GAN与最先进的扩散蒸馏模型进行比较,COCO2014和COCO2017的结果如表1和表2所示:InstaFlow-0.9B在COCO2014中实现了13.10的FID,在COCO2017中实现了23.4的FID,而Diffusion2GAN则达到FID9.29和FID19.5。还使用了其他扩散模型,但Diffusion2GAN在训练时紧跟原始模型的轨迹,在保持高视觉质量的同时减轻了多样性崩溃问题,而ADD-M使用ViT-g-14文本编码器,并获得了较高的CLIP-5k分数、Diffusion2GAN在训练时没有使用该编码器。
表1:与2014年COCO会议最新文本到图像模型的比较
表2.COCO2017与近期文本到图像模型的比较
视觉分析
图4.与Stable Diffusion 1.5教师的直观比较
图4直观地比较了拟议方法与Stable Diffusion 1.5、LCM-LoRA和InstaFlow。由于随着无分类器引导(CFG)规模的扩大,扩散模型往往能生成更逼真的图像,因此建议的方法使用SD-CFG-8数据集训练Diffusion2GAN,并使用相同的引导规模(8)将其与Stable Diffusion 1.5进行比较;对于LCM-LoRA和InstaFlow,则采用各自的最佳设置,以确保比较的公平性。拟议方法的结果表明,与其他蒸馏基线相比,它生成的图像更加逼真,同时还保留了稳定扩散教师生成的目标图像的整体布局。
训练速度
即使算上ODE数据集的准备成本,Diffusion2GAN也比现有的蒸馏方法收敛得更有效:在CIFAR10数据集上,比较了生成器网络在整个训练期间的函数评估总数。建议的方法可以证实,使用LPIPS loss进行的训练已经超过了一致性蒸馏法在500k监督输出下的FID(表3)。在文本到图像的合成方面,Diffusion2GAN的完整版本比InstaFlow实现了更好的FID,而且GPU日数也少得多(表4)。
表3.CIFAR的收敛性比较10
表4.所需计算资源与图像质量的比较
总结
在这项工作中,我们提出了一个新的框架Diffusion2GAN,将预先训练好的多步扩散模型提炼成一个根据条件GAN和感知损失训练好的单步生成器。所提出的方法表明,将生成建模分成两项任务(识别对应关系和学习映射)可以提高使用不同生成模型的性能和计算效率。这种简单的方法不仅能改善交互式图像生成,还能提高视频和三维应用的效率。
然而,所提出的方法有几个局限性:第一,它使用固定的指导尺度,这意味着在推理过程中不能使用不同的值;第二,它依赖于教师模型的质量,这限制了它的性能;第三,增加学生和教师模型的尺度仍会导致多样性多样性的退化。如果能解决这三个问题,未来的研究将使该模型更加实用。