问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

LLM量化新篇章:4-bit权重激活量化几乎无损!FlatQuant的平坦之道

创作时间:
作者:
@小白创作中心

LLM量化新篇章:4-bit权重激活量化几乎无损!FlatQuant的平坦之道

引用
1
来源
1.
https://aijishu.com/a/1060000000488895

本文介绍来自华为诺亚方舟实验室、清华大学和香港中文大学联合在大语言模型量化上的最新工作FlatQuant(Fast and Learnable Affine Transformation)。FlatQuant通过为每个线性层适配轻量的可学习的仿射变换,有效平滑LLM离群值,得到更加平坦的权重和激活值分布,有效降低量化损失。相比此前的量化方法,本方法首次在LLaMA-3-70B上达到W4A4<1%的精度损失,并可带来最高2.3x prefill和1.7x decoding加速比。

1. 大语言模型 (LLM) W4A4 量化问题

模型量化是大语言模型 (LLM) 推理加速的常用技术,可以通过将权重和激活值同时压缩到低比特来有效降低访存开销,并利用峰值算力更高的INT4/8 Tensor Core完成矩阵运算,从而带来实际的推理加速比。

然而,目前的W4A4(权重4位,激活值4位)量化模型相比全精度模型还存在着较大的量化损失,难以在实际应用中使用,也就难以利用峰值算力最高的INT4 Tensor Core加速LLM的实际推理部署。我们发现,量化前权重和激活值分布的平坦度(flatness)是影响LLM量化误差的关键因素。

直观来看,分布越平坦,离群值就越少,量化时的精度也就越高。已有方法大多使用pre-quantization transformations,通过在量化前对权重和激活值做等价变换得到更平坦的分布来降低量化误差,常用的变换主要有Per-channel Scaling [1]和Hadamard变换 [2]。

FlatQuant推动W4A4 LLM部署

然而,我们发现这些变换并不是最优的,为此我们提出FlatQuant(Fast and Learnable Affine Transformation),为每个线性层学习一个最优的仿射变换来有效缓解权重和激活值上的离群值,从而得到平坦的权重和激活值分布,有效提升了量化精度。此外,针对推理中的在线变换,我们进行了算子融合进一步降低访存开销,使得在线变换仅带来极小的推理开销。

实验表明,FlatQuant在W4A4的设置下极大地减少了量化模型的精度损失,甚至在部分模型上达到了接近无损的效果(e.g. LLaMA-3-70B),轻量的在线变换也使得FlatQuant能达到2.3x的prefill和1.7x的decoding加速比。我们希望FlatQuant能进一步推动W4A4 LLM的实际部署,从而更加有效地降低LLM的推理成本。

2. 探索平坦分布与量化损失的优化路径

LLM的权重和激活值上存在较多的离群值,特别是激活值上常常存在离群值通道(outlier channels),导致LLM难以量化。目前针对LLM WA量化的方法大多在量化前对权重和激活值做等价变换来用其他通道吸收离群值,从而得到更加平坦的分布以降低量化损失。例如:

为降低理解的难度,NeuralTalk在此举一个例子。

权重和激活值分布可以看作两个斜坡,变换就类似于用铲子搬土,土不会凭空增加或者减少,所以目标是通过把两个坡中高处(离群值)的土填到低处(非离群值通道),从而把这两个坡填平。

  • Per-channel Scaling就相当于只能把一个坡上的土填到另一个坡的相同位置上,比较局限。
  • Hadamard变换相当于在每个坡的内部把高处的土填到自身的低处,但不能在两个坡之间转移土。并且由于不同坡的形状不同,相同的Hadamard变换(坡内搬土方式)不一定适用于所有土坡。

相比之下,FlatQuant方法可以被看作是一种更加精细和智能的“搬土”策略。在这个方法中,我们不再局限于只在单个斜坡内部移动土,也不只是在两个斜坡的相同位置上进行土的转移。相反,FlatQuant允许我们对每个斜坡进行定制化的调整,这意味着我们可以针对每个斜坡的独特形状和需求,设计出最佳的“搬土”方案。

这就相当于为模型的每一层学习一个特定的仿射变换,以得到更平坦的分布,并且可以自适应地平衡权重和激活值的量化难度。

2.1 平坦分布的追求与挑战

在下图1中,我们画出了LLM的不同权重和激活值在变换前后的分布情况,理想情况下,我们希望能利用所有通道吸收离群值,使得变换后的分布呈现一条平坦的水平线。

图1:等价变换前后模型的权重和激活值分布,具体来说,按通道幅值(即Frobenius范数)降序排列的LLaMA-3-8B和LLaMA-3-70B的权重和输入的分布情况。

注:在Transformer层中,Wo和Xo分别表示自注意力层输出投影层的权重矩阵和输入。Wg和Xg分别表示前馈网络中门控线性层的权重和输入。更多的可视化内容可以在文章附录D中找到。四个图分别是:(a) LLaMA-3-8B的第10层Transformer的Wo。(b) LLaMA-3-8B的第10层Transformer的Xo。(c) LLaMA-3-70B的第30层Transformer的Wg。(d) LLaMA-3-70B的第30层Transformer的Xg。

但如上图1所示,我们发现已有的等价变换得到的分布仍然可能是不平坦的:

  • Per-channel Scaling,离群值仍然被限制在了权重和激活值的相同通道上,非离群值通道得不到有效利用,因此不管是权重还是激活值,变换后的分布都非常陡峭,呈现出非常明显的离群值通道。
  • Hadamard变换对所有权重和激活值都施加相同的变换,而不同层的权重和激活值分布是不同的,这意味着Hadamard变换并不是对于每个层的最优解,例如图1(a)(b)中,LLaMA-3-8B的权重和激活值经过Hadamard变换后仍然比较陡峭,特别是激活值上的离群值无法得到有效平滑。此外,Hadamard变换作为一种正交变换不会改变向量的模长,而LLM激活值上大量的离群值会导致激活值模长显著大于权重,这导致正交变换后的激活值量化难度也会显著高于权重,无法像Per-channel Scaling一样灵活地平衡权重和激活值上的量化难度。

相比之下,FlatQuant通过给每一层针对性地学习仿射变换,不仅可以得到平坦的分布,还可以自适应地平衡权重和激活值的量化难度。

2.2 不同等价变换下的量化损失平面

下面的图2中,我们画出了不同变换后LLM的量化损失平面,可以发现,per-channel scaling和Hadamard变换都无法很好处理具有massive outlier [3]的关键词元(pivot token),导致在首词元上具有非常大的量化误差,已有研究表明关键词元上的量化误差会比较严重地影响模型的量化精度[4]。

图2:不同等价变换下的量化损失平面。明显看出FlatQuant方法的MSE更小。

相比之下,FlatQuant则可以显著降低关键词元上的量化损失,并有效抑制量化误差的逐层传播,带来更加平坦的量化损失平面。

2.3 方法概述

轻量仿射变换

Kronecker Decomposition

Per-channel Scaling

Learnable Clipping Thresholds

我们对变换后的权重和激活值进一步采用了learnable clipping来更好地消除离群值。

以上就是关键方法,分步来说:

  1. 轻量仿射变换:通过学习每个线性层的最优仿射变换来平滑离群值。
  2. Kronecker分解:将大的变换矩阵分解为小矩阵,减少存储和计算开销。
  3. Per-channel Scaling:为每个通道提供独立的缩放因子,增加变换的灵活性。
  4. Learnable Clipping Thresholds:通过可学习的裁剪阈值进一步减少离群值的影响。

优化过程

损失函数采用Layer-wise MSE loss:

模型架构

如图下图3所示,FlatQuant在单个Transformer内会引入5种不同的在线变换,对于LLaMA-2-7B,这些在线变换在序列长度2K时的FLOPs仅为FP16模型的2.61%,对在线变换中两个小矩阵乘以及量化操作的算子融合还可以帮助进一步降低FlatQuant的额外推理开销。

图3:FlatQuant模型架构图

另外注意到,在QuaRot [2]和SpinQuant [5]中,为了降低在线推理开销,MHA / MLP输入处的正交变换会被融合到前序线性层里,但由于残差连接的限制,不同Transformer block中的MHA / MLP都必须共享输入处的正交变换,这不仅限制了变换的灵活性,还使得在优化变换矩阵时必须采用端到端优化,需要较大的训练开销。

相比之下,FlatQuant不仅可以对每个线性层都学得最适配的等价变换,还可以逐层优化,仅需单卡即可完成对70B模型的量化。

3. 实验结果

量化设置. 实验中,我们保持了与QuaRot [2]相同的量化设置,权重和激活值分别采用per-channel和per-token对称量化,KV cache量化采用group-wise非对称量化(g128),校准集为来自WkiText-2数据集的128条样本。

3.1 量化精度

我们测试了W4A4下量化模型的PPL和QA任务上的精度结果,从表1和表2中可以看到,FlatQuant在使用RTN作为weight quantizer时精度就已经能比较明显地超过QuaRot和SpinQuant使用GPTQ的效果。

  • 对于较大的13B/70B模型,QA精度损失均在1%左右。
  • 更小的7B/8B模型的精度损失也维持在了2%左右。
  • FlatQuant对于更难量化的LLaMA-3模型提升尤为明显,例如LLaMA-3-70B的QA任务上FlatQuant相比SpinQuant有超过7%的精度提升,同时与全精度模型的精度差距保持在1%以内。

表1:W4A4 PPL实验结果

表2:W4A4 zero-shot QA任务实验结果

3.2 端到端加速比

我们在RTX3090上测试了FlatQuant的prefill/decoding端到端加速比。如图4所示,FlatQuant最高能带来2.30x的prefill和1.76x的decoding加速比,推理速度超过了QuaRot,相比INT4也仅有极小的加速比损失。

图4:端到端加速比

3.3 更多实验

(1) 消融实验. 从表3中可以看到,在RTN量化的基础上加入LT(Learnable Transformation)就已经能极大地提升量化模型精度,进一步加入PS(Per-channel Scaling)和LCT(Learnable Clipping Thresholds)还能进一步提升模型精度。

表3:LLaMA-3-8B消融实验

(2) 权重量化. FlatQuant在权重量化上也能与SOTA的uniform量化方法达到相当的精度。

(3) Train One and Get More. FlatQuant中W4A4量化设置下学到的变换矩阵可以直接用在其他量化设置下,这使得我们能更加便利地在不同量化设置下使用FlatQuant。

表5, 6:更多量化设置

4. 总结

现有的量化方法在W4A4下量化损失大难落地。量化前权重和激活值分布的平坦度显著影响量化误差。FlatQuant通过为每个线性层适配轻量的可学习的仿射变换,平滑权重和激活值上的离群值,得到更平坦的分布解决该问题。

在LLaMA-3-70B模型上实现小于1%的量化损失。部分模型接近无损效果。性能上FlatQuant带来高达2.3倍的prefill加速和1.7倍的decoding加速。

总的来说,FlatQuant是一种创新的量化方法,它通过学习最优的仿射变换来提高LLM量化精度,保持高加速比的同时,显著降低量化损失。这项工作对于推动大型语言模型在实际应用中的部署具有重要意义。

5. 参考文献

  • [1] Xiao, Guangxuan, et al. “Smoothquant: Accurate and efficient post-training quantization for large language models.” International Conference on Machine Learning. PMLR, 2023.
  • [2] Ashkboos, Saleh, et al. "Quarot: Outlier-free 4-bit inference in rotated llms." arXiv preprint arXiv:2404.00456 (2024).
  • [3] Sun, Mingjie, et al."Massive Activations in Large Language Models."arXiv preprint arXiv:2402.17762 (2024).
  • [4] Liu, Ruikang, et al."IntactKV: Improving Large Language Model Quantization by Keeping Pivot Tokens Intact."arXiv preprint arXiv:2403.01241(2024).
  • [5] Liu, Zechun, et al."SpinQuant--LLM quantization with learned rotations."arXiv preprint arXiv:2405.16406(2024).
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号