问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

模型压缩:剪枝、量化、蒸馏

创作时间:
2025-03-09 23:58:07
作者:
@小白创作中心

模型压缩:剪枝、量化、蒸馏

引用
CSDN
1.
https://blog.csdn.net/u010618499/article/details/137479217

模型压缩是深度学习领域中一个重要的研究方向,其目的是在保持模型精度的同时,减少模型的计算量和存储需求。本文将详细介绍三种主流的模型压缩方法:剪枝、量化和知识蒸馏,并通过PyTorch框架进行具体实现。

1. 剪枝(Pruning)

剪枝是一种通过减少模型中的参数量来进行模型压缩的技术,可以在保证一定精度的情况下减少模型的内存占用和硬件消耗。

PyTorch剪枝实现

PyTorch提供了多种剪枝规则,包括:

  • RandomUnstructured
  • L1Unstructured
  • RandomStructured
  • LnStructured
  • CustomFromMask

下面是一个使用PyTorch进行全局剪枝的示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import LeNet
from torch.nn.utils import prune

# 实例化神经网络
model = LeNet()

# 配置可剪枝的网络层和参数名
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

# 全局剪枝,采用L1Unstructured的方法,剪去0.2的参数量
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# 查看剪枝后的模型参数
print(model.state_dict())

2. 量化(Quantization)

模型量化是一种压缩网络参数的方式,它将神经网络的参数(weight)、特征图(activation)等原本用浮点表示的量值换用定点(整型)表示,同时尽可能减少计算精度损失。在计算过程中,再将定点数据反量化回浮点数据,得到结果。

PyTorch量化实现

PyTorch提供了三种量化方式:

  • 量化感知训练(Quantization Aware Training, QAT)
  • 训练后动态量化(Post Training Dynamic Quantization)
  • 训练后校正量化(Post Training Static Quantization)

以下是具体的实现代码:

import torch.quantization

# 训练后动态量化
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# 训练后校正量化
deploymentmyModel.qconfig = torch.quantization.get_default_config('fbgemm')
torch.quantization.prepare(myModel, inplace=True)
torch.quantization.convert(myModel, inplace=True)

# 量化感知训练
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(qat_model, inplace=True)
epochquantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)

量化方法分类

根据量化方案分类:

  • 在线量化(QAT):在量化的同时结合反向传播对模型权重进行调整
  • 离线量化(PTQ)
  • 训练后动态量化:不使用校准数据集
  • 训练后校正量化:需要输入有代表性的数据集

根据量化公式分类:

  • 线性量化:浮点数与定点数之间的转换公式为 Q = R / S + Z
  • 对称量化:量化前后的0点是对齐的
  • 非对称量化:量化前后0点不对齐,需要额外记录一个offset

Block-wise量化

为了避免异常值的影响,可以将输入tensor分割成一个个block,每个block单独做量化,有单独的scale和zero,从而减少量化的精度损失。

3. 知识蒸馏(Knowledge Distillation)

知识蒸馏是一种技术,可以将知识从计算成本较高的大型模型转移到较小的模型,而不会失去有效性。具体的方法是在训练小模型时,在损失函数中添加额外的损失函数。

损失函数添加方式

  1. 输出层的差异损失

optimizer.zero_grad()
ce_loss = nn.CrossEntropyLoss()

# 冻结教师网络的权重,计算教师网络的输出层结果
with torch.no_grad(): 
    teacher_logits = teacher(inputs) 

# 学生网络的输出层结果 
student_logits = student(inputs) 

# 计算软目标损失
soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)
soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2) 

# 学生网络的交叉熵损失
label_loss = ce_loss(student_logits, labels) 

# 将软目标损失添加到交叉熵损失中
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss 
loss.backward()
optimizer.step()
  1. 隐藏层的相似度损失(余弦相似度损失)
optimizer.zero_grad()
ce_loss = nn.CrossEntropyLoss()
cosine_loss = nn.CosineEmbeddingLoss()

# 冻结教师网络的权重,只记录隐藏层的结果
with torch.no_grad(): 
    _, teacher_hidden_representation = teacher(inputs)

# 学生网络的推理结果
student_logits, student_hidden_representation = student(inputs) 

# 计算学生网络和教师网络在隐藏层的余弦损失
hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))

# 学生网络的交叉熵损失
label_loss = ce_loss(student_logits, labels)

# 将余弦相似度损失添加到交叉熵损失中
loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss
loss.backward()
optimizer.step()
  1. 中间层的回归损失(均方误差,MSE)
optimizer.zero_grad()
ce_loss = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

# 冻结教师网络的权重,只记录中间层的特征
with torch.no_grad(): 
    _, teacher_feature_map = teacher(inputs)

# 学生网络的推理结果
student_logits, regressor_feature_map = student(inputs) 

# 计算学生网络和教师网络在中间层的均方误差
hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)

# 学生网络的交叉熵损失
label_loss = ce_loss(student_logits, labels)

# 将中间层的均方误差添加到交叉熵损失中
loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss
loss.backward()
optimizer.step()
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号