一文详解大模型微调常用方法
一文详解大模型微调常用方法
随着深度学习的发展,大型预训练模型如GPT-3、ChatGPT等在自然语言处理任务中展现出卓越性能。然而,这些模型的训练成本极高,需要大量计算资源和数据。为了解决这一问题,研究人员开始探索Parameter-Efficient Fine-Tuning (PEFT)技术,通过最小化微调参数的数量和计算复杂度,使得即使计算资源受限,也能利用预训练模型的知识快速适应新任务。本文将详细介绍几种常用的PEFT方法,包括Adapter Tuning、Prompt Tuning、Prefix Tuning、P-Tuning及其变种,以及AdaLoRA等。
Adapter Tuning
2019年,谷歌的研究人员首次在论文《Parameter-Efficient Transfer Learning for NLP》中提出了针对BERT的PEFT微调方式,拉开了PEFT研究的序幕。他们设计了如下图所示的Adapter结构,将其嵌入Transformer的结构里面,在训练时,固定住原来预训练模型的参数不变,只对新增的Adapter结构进行微调。
同时为了保证训练的高效性(也就是尽可能少的引入更多参数),他们将Adapter设计为这样的结构:
- 首先是一个down-project层将高维度特征映射到低维特征
- 然后过一个非线形层之后,再用一个up-project结构将低维特征映射回原来的高维特征
- 同时也设计了skip-connection结构,确保了在最差的情况下能够退化为identity(类似残差结构)。
从实验结果来看,该方法能够在只额外对增加的3.6%参数规模(相比原来预训练模型的参数量)的情况下取得和Full-Finetuning接近的效果(GLUE指标在0.4%以内)。
Prefix Tuning
2021年,斯坦福的研究人员在论文《Prefix-Tuning: Optimizing Continuous Prompts for Generation》中提出了Prefix Tuning方法。与Full-finetuning更新所有参数的方式不同,该方法是在输入token之前构造一段任务相关的virtual tokens作为Prefix,然后训练的时候只更新Prefix部分的参数,而Transformer中的其他部分参数固定。
同时,为了防止直接更新Prefix的参数导致训练不稳定的情况,他们在Prefix层前面加了MLP结构(相当于将Prefix分解为更小维度的Input与MLP的组合后输出的结果),训练完成后,只保留Prefix的参数。
Prompt Tuning
Prompt Tuning是2021年谷歌在论文《The Power of Scale for Parameter-Efficient Prompt Tuning》中提出的微调方法。该方法可以看作是Prefix Tuning的简化版本,只在输入层加入prompt tokens,并不需要加入MLP进行调整来解决难训练的问题,主要在T5预训练模型上做实验。似乎只要预训练模型足够强大,其他的一切都不是问题。作者也做实验说明随着预训练模型参数量的增加,Prompt Tuning的方法会逼近Fine-tune的结果。
固定预训练参数,为每一个任务额外添加一个或多个embedding,之后拼接query正常输入LLM,并只训练这些embedding。左图为单任务全参数微调,右图为Prompt tuning。
作者做了一系列对比实验,都在说明:随着预训练模型参数的增加,一切的问题都不是问题,最简单的设置也能达到极好的效果。
- Prompt长度影响:模型参数达到一定量级时,Prompt长度为1也能达到不错的效果,Prompt长度为20就能达到极好效果。
- Prompt初始化方式影响:Random Uniform方式明显弱于其他两种,但是当模型参数达到一定量级,这种差异也不复存在。
- 预训练的方式:LM Adaptation的方式效果好,但是当模型达到一定规模,差异又几乎没有了。
- 微调步数影响:模型参数较小时,步数越多,效果越好。同样随着模型参数达到一定规模,zero shot也能取得不错效果。
- 当参数达到100亿规模与全参数微调方式效果无异。
P-Tuning v1
P-Tuning方法的提出主要是为了解决这样一个问题:大模型的Prompt构造方式严重影响下游任务的效果。P-Tuning提出将Prompt转换为可以学习的Embedding层,只是考虑到直接对Embedding参数进行优化会存在这样两个挑战:
- Discretenes:对输入正常语料的Embedding层已经经过预训练,而如果直接对输入的prompt embedding进行随机初始化训练,容易陷入局部最优。
- Association:没法捕捉到prompt embedding之间的相关关系。作者在这里提出用MLP + LSTM的方式对prompt embedding进行一层处理:
P-tuning依然是固定LLM参数,利用多层感知机和LSTM对Prompt进行编码,编码之后与其他向量进行拼接之后正常输入LLM。注意,训练之后只保留Prompt编码之后的向量即可,无需保留编码器。
P-Tuning v2
P-Tuning的问题是在小参数量模型上表现差(如上图所示)。于是就有了v2版本:《P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks》。
从标题就可以看出,P-Tuning v2的目标就是要让Prompt Tuning能够在不同参数规模的预训练模型、针对不同下游任务的结果上都达到匹敌Fine-tuning的结果。那也就是说当前Prompt Tuning方法在这两个方面都存在局限性。
- 不同模型规模:Prompt Tuning和P-tuning这两种方法都是在预训练模型参数规模够足够大时,才能达到和Fine-tuning类似的效果,而参数规模较小时效果则很差。
- 不同任务类型:Prompt Tuning和P-tuning这两种方法在sequence tagging任务上表现都很差。
主要结构
相比Prompt Tuning和P-tuning的方法,P-tuning v2方法在多层加入了Prompts tokens作为输入,带来两个方面的好处:
- 带来更多可学习的参数(从P-tuning和Prompt Tuning的0.1%增加到0.1%-3%),同时也足够parameter-efficient。
- 加入到更深层结构中的Prompt能给模型预测带来更直接的影响。
v1到v2的可视化:蓝色部分为参数冻结,橙色部分为可训练部分。
几个关键设计因素
- Reparameterization:Prefix Tuning和P-tuning中都有MLP来构造可训练的embedding。本文发现在自然语言理解领域,面对不同的任务以及不同的数据集,这种方法可能带来完全相反的结论。
- Prompt Length:不同的任务对应的最合适的Prompt Length不一样,比如简单分类任务下length=20最好,而复杂的任务需要更长的Prompt Length。
- Multi-task Learning多任务对于P-Tuning v2是可选的,但可以利用它提供更好的初始化来进一步提高性能。
- Classification Head使用LM head来预测动词是Prompt Tuning的核心,但我们发现在完整的数据设置中没有必要这样做,并且这样做与序列标记不兼容。P-tuning v2采用和BERT一样的方式,在第一个token处应用随机初始化的分类头。
实验结果
- 不同预训练模型大小下的表现,在小模型下取得与Full-finetuning相近的结果,并远远优于P-Tuning。
- 不同任务下的P-Tuning v2效果都很好,而P-Tuning和Prompt Learning效果不好;同时,采用多任务学习的方式能在多数任务上取得最好的结果。
AdaLoRA
预训练语言模型中的不同权重参数对下游任务的贡献是不同的。因此需要更加智能地分配参数预算,以便在微调过程中更加高效地更新那些对模型性能贡献较大的参数。具体来说,通过奇异值分解将权重矩阵分解为增量矩阵,并根据新的重要性度量动态地调整每个增量矩阵中奇异值的大小。这样可以使得在微调过程中只更新那些对模型性能贡献较大或必要的参数,从而提高了模型性能和参数效率。
Towards a Unified View of PETL
这篇ICLR2022的文章研究了典型的PEFT方法,试图将PEFT统一到一个框架下,找出它们起作用的具体原因,并进行改进。主要研究了三个问题:
- 典型的PEFT方法有什么联系?
- 典型的PEFT方法中是哪些关键模块在起作用?
- 能否对这些关键模块进行排列组合,找出更有用的PEFT方法?
通过对Prefix Tuning的推导,得出了和Adapter Tuning以及LoRA形式一致的形式。包括这几大要素:
- 的形式
- 嵌入Transformer结构的方式(分为Parrell和Sequential两种。Parallel指的是在输入层嵌入,这样与原有结构可以并行计算;Sequential指的是在输出层嵌入,相当于增加了网路的深度,与原有结构存在依赖关系)
- 修改表示层(主要指对attention层的修改还是对ffn层的修改)
- 组合方式。怎么与原有的参数组合,包括简单相加(Adapter)、门控式(Prefix Tuning)、缩放式(LoRA)三种)
根据这个统一的框架,还另外设计了三种变体Parallel Adapter、Multi-head Parallel Adapter、Scaled Parallel Adapter。