Wanda:一种简单有效的LLM剪枝方法
Wanda:一种简单有效的LLM剪枝方法
近年来,随着大模型规模的持续扩大,模型剪枝作为降低计算成本的重要手段备受关注。然而,现有的剪枝方法要么需要昂贵的再训练过程,要么依赖复杂的二阶信息重建。针对这一痛点,2023年提出了一种名为Wanda的新型剪枝方法,通过结合权重和激活值的信息,实现了无需再训练的高效剪枝。
现有剪枝方法的挑战
大模型的快速发展带来了巨大的计算负担,模型量化和模型剪枝是两种主要的压缩方法。其中,模型剪枝通过移除特定权重来减小网络规模,但现有方法普遍存在以下问题:
- 传统网络剪枝:需要将特定权重设置为零,通常需要再训练。
- 幅度剪枝:基于权重的绝对值进行剪枝,但可能误剪重要特征。
- 再训练需求:许多方法需要剪枝后对模型进行再训练。
- 从随机初始化开始训练:需要从随机初始化开始进行剪枝和训练。
- 复杂的迭代过程:如Lottery Ticket Hypothesis需要多次剪枝和再训练。
- SparseGPT:虽然不需要传统再训练,但需要复杂的权重更新过程。
Wanda方法的核心原理
Wanda的创新在于结合了权重和激活值的信息。具体来说,Wanda在计算权重和输入激活的乘积后,选择最小的权重进行剪枝。这种方法能够更准确地识别并保留那些尽管权重较小但因激活值较大而对模型性能至关重要的特征。
左图:Magnitude Pruning选取50%最小的权重并全部赋值0;右图:Wanda在计算权重和激活相乘后的结果,再选取50%最小的位置并将对应位置的权重赋值为0。
Wanda的实现代码
以下是Wanda的PyTorch实现代码:
def prune(W, X, s):
metric = W.abs() * X.norm(p=2, dim=0)
_, sorted_idx = torch.sort(metric, dim=1)
pruned_idx = sorted_idx[:, :int(C_in * s)]
W.scatter_(dim=1, index=pruned_idx, src=0)
return W
输入参数:
W
:权重矩阵,形状为(C_out, C_in)
X
:输入矩阵,形状为(N * L, C_in)
s
:所需的稀疏度,值在0和1之间计算剪枝指标:
计算每个权重的绝对值乘以对应输入激活的L2范数,得到剪枝指标
排序和选择要剪除的权重:
对指标进行排序
选择需要剪除的权重索引
执行剪枝:
根据选择的索引将相应的权重设置为零
返回剪枝后的权重矩阵:
结构化稀疏性
Wanda最初是为非结构化稀疏性设计的,但也可以轻松扩展到结构化稀疏性。在结构化稀疏性中,每M个连续权重中最多只有N个权重是非零的。这种稀疏模式可以利用NVIDIA的稀疏张量核心加速矩阵乘法。
实验评估
Wanda在LLaMA系列模型(7B/13B/30B/65B)和LLaMA-2系列模型(7B/13B/70B)上进行了评估,使用WikiText数据集的验证集评估困惑度。实验主要关注线性层的剪枝效果,跳过了第一个嵌入层和最后一个分类层(这些层占总LLM参数量的约99%),并施加了统一的稀疏性设置,包括非结构化稀疏性和结构化的4:8、2:4稀疏性。
实验结果表明,Wanda在保持模型性能的同时,显著降低了计算成本,为大模型的部署和应用提供了新的可能性。