论文分析|高效长文本生成的技术与应用
论文分析|高效长文本生成的技术与应用
随着Transformer模型尺寸和复杂性的增长,它们在训练期间的内存需求呈指数级增加。为了解决这一问题,一篇最新前沿论文提出了一种名为MINI-SEQUENCE TRANSFORMER(MST)的技术,用于高效准确地训练极长序列的大型语言模型(LLMs),通过减少中间内存使用,实现了显著的内存节省,而不影响模型性能或训练收敛速度。
前言
目前大模型公司很多在追求长文a本, 对算力需求极大,如何能够现实地处理该问题很重要。特别是随着Transformer模型尺寸和复杂性的增长,它们在训练期间的内存需求呈指数级增加。
语言模型训练的瓶颈在于显存占用非常大,这需要创新的解决方案来优化内存使用,同时保持性能。
本次将介绍一篇最新前沿论文,提出了一种名为MINI-SEQUENCE TRANSFORMER(MST)的技术,用于高效准确地训练极长序列的大型语言模型(LLMs),通过减少中间内存使用,实现了显著的内存节省,而不影响模型性能或训练收敛速度。MST方法通用、易于集成,并且支持分布式训练。
论文链接 🔗
https://arxiv.org/abs/2407.15892
贡献者:
Cheng Luo, Jiawei Zhao, Zhuoming Chen,
Beidi Chen, Anima Anandkumar
前沿背景
训练时必须需要在显存里存储以下内容:
1. 权重:模型的参数,包括所有层的权重矩阵,需要在训练前加载到显存中。
2. 正向传播时需要储存:
a. 激活值(Activations):计算并存储每层的激活,这些值需要保存以便于在反向传播时计算梯度。
b. **中间值(Intermediate Values)——计算时每一层时都需要储存:在模型的不同层,特别是多头自注意力(Multi-Head Attention)层和多层感知器(MLP)层中,计算过程中会产生中间值,如Q(Query)、K(Key)、V(Value)张量,以及MLP层的中间线性变换结果。
拆解开来主要是:
- Transformer 中的 Attention 层,计算 QKV 和注意力矩阵;
- Transformer 中的 MLP 层,将序列嵌入维度放大再缩小;
- LM Head (Language Modeling Head)将嵌入映射为 logits,之后计算 loss / 下一个 token。
3. 反向传播时需要计算并存储:梯度(Gradients)——在反向传播过程中计算得到的模型权重的梯度,用于更新模型参数,保持优化器的状态等等。
技术创新
文章提出了MST方法:多个小序列(Mini-sequences)迭代处理
核心思路:通过将输入序列划分为多个小序列(mini-sequences),并迭代处理这些小序列,从而减少了中间内存的使用。
MST方法关注的是计算过程中间状态,因为Transformer 在计算过程中会产生非常巨大的中间状态:
直接计算 Attention 会产生 NxN 大小的注意力矩阵
MLP 层也需要将嵌入维度放大四倍左右
LM Head 的映射会产生几万到几十万维度的矩阵,因为要计算整个词表中每个单词的概率
但与此相对的是 GPU 的显存是有限的,而且 GPU 的显存分了不同层次,速度快的 HBM 非常小,搬运到低层次的 SRAM 很费时间。因此降低中间状态的大小就非常重要。所幸语言模型的计算在token 之间依赖比较小,因此可以并行操作,这就导向一个关键技巧,也就是分块处理。
假设有 16 个 token 要处理,每个 token 处理时需要X的内存,如果一次性处理,那就需要 16X 的内存;但如果分两块处理,那么整个过程最多只需要 8X 的内存;第一次计算完成后,相应内存就可以释放了。当然,不是分块越多越好,分块越多,各种搬运记录计算结果,汇总结果所需的时间也会越长。
FlashAttention 其实就在使用这个技巧,因为 Attention 的中间状态太大了,随着序列长度二次方增长谁也受不了。随着 Attention 的中间状态被 FlashAttention 和 Ulysses 打下来,我们自然就盯上其他中间状态了。
本文就在讨论分块解决 Llama 3 中 MLP 层和 LM Head 的中间状态。
分解过程都是类似的,都是 分解、计算、汇总 。设序列长度为 N, 嵌入维度为 d,序列嵌入 X∈R^(Nd)那么就将 X 沿着分成 M 块,每块计算结果,然后汇总。
MLP层:Llama3 的 MLP 层有三个线性层 Wgate,Wup,Wdown。前向计算过程为:
• 首先将 Nd 大小的序列矩阵通过 Wgate, Wup 放大为两个 Nl 的矩阵
• 然后两个矩阵逐元素点乘,最后通过 Wdown 缩小为原来的 Nd 大小。
汇总的过程就很简单,直接在序列维度拼接就可以。
LM Head 计算 : 计算loss 的过程是将序列矩阵右乘一个矩阵映射到词表,也就是得到 NxV 的矩阵,然后算交叉熵。此时需要检查标签是否在块的范围内。汇总的过程是将 loss 加和再除以 M 取平均。
MST 没有改变 FLOP,但会增加 HBM 访问次数,标准 MLP 的访问次数为 Θ (Nd+NI+dI),而 MST 的访问次数为 Θ (Nd+NI+dIM)。标准 LM Head 的次数是 Θ (Nd+NI+dV),而 MST 下这个数字是** Θ (Nd+NI+dVM)。如果序列长度很短,中间状态占主导, dI,dV **占主导,MST会降低 GPU 显存的访问量。
研究结果
论文通过实验评估了MST在不同模型上的性能,包括Llama2和Llama3,并展示了在不同序列长度和批处理大小下的训练时间和内存开销。
- Loss 基本没有损失,Llama 3 8B 上最大序列长度是基本实现的12-20倍,激活重计算的1.8-4倍;
- 计算速度有些许损失;
- 峰值内存下降了很多。
结果分析:
- MST通用性和易用性:MST是完全通用的,并且可以轻松集成到现有的LLM训练框架中,只需最小的代码更改。
- 分布式训练支持:MST可以与DeepSpeed-Ulysses等技术一起工作,支持按GPU数量线性扩展序列长度。
- MST的局限性和未来的发展方向:例如将方法编译到CUDA以提高性能和内存效率,以及在小序列训练中可能出现的性能下降问题。
后续思考
- 模型规模的增大和词表的扩展(大词表):传统的串行计算方法可能无法满足效率需求,长序列下模型训练和推理会出现新的瓶颈,需要从硬件出发设计算法进行并行优化,如利用现代硬件(如GPU、TPU)的并行处理能力,可以显著加速模型的训练和推理过程。
- 分块计算:减少中间状态的思想非常简单但也非常好用,可以优化优化器的中间状态,值得继续推广
💡一点思考:长文本的能力是否值得长期探索?
- 一方面,长文本本身的应用可能并没有那么多或者需求并不强烈,如debug 代码库里的代码,但是在多模态的发展上,处理long sequence确实十分重要。
- 另一方面,长文本的能力在一定程度上反应智能的水平。人脑限制在20瓦,在这个限制下自然会放弃对一些记忆细节的把握,但是机器学习模型没有这方面的限制,因此没有必要过于参考人类生理进化的过程。eg:如果大模型能够同时记住所有学科的很多细节,那么很可能在科学研究上的突破性贡献要强于人类。