问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

知识蒸馏详解及PyTorch官网Demo案例

创作时间:
作者:
@小白创作中心

知识蒸馏详解及PyTorch官网Demo案例

引用
CSDN
1.
https://blog.csdn.net/qq_52048052/article/details/137168069

知识蒸馏(Knowledge Distillation, KD)是一种模型压缩方法,通过将大型模型(教师模型)的知识迁移到小型模型(学生模型)中,以实现模型的压缩和加速。本文将详细介绍知识蒸馏的核心概念、不同形式的知识表示以及具体的实现方法,并通过PyTorch示例代码展示如何实现知识蒸馏。

知识蒸馏(Knowledge Distillation, KD)

1. 简介

一种模型压缩方法

知识蒸馏的一般框架(如下图)包含三部分:知识、蒸馏算法、师生架构。

知识

知识可以分为三种形式:基于响应的(response-based)、基于特征的(feature-based)、基于关系的(relation-based)。

① 基于响应的知识(response-based)【常用】

学习的知识是教师模型最后一个输出层logits。由于logits实际上是类别概率分布,因此基于响应的知识蒸馏限制在监督学习

最流行的基于响应的图像分类知识被称为软目标(soft target)。基于响应的知识蒸馏具体架构如下图。后面具体介绍该类知识蒸馏。

② 基于特征的知识(feature-based)

学习的知识是教师模型中间层的基于特征的知识。下图为基于特征的知识蒸馏模型的通常架构。

③ 基于关系的知识(relation-based)

基于响应和基于特征的知识都使用了教师模型中特定层的输出,基于关系的知识进一步探索了不同层或数据样本的关系。下图为实例关系的知识蒸馏架构。

蒸馏机制

根据教师模型是否与学生模型同时更新,知识蒸馏的学习方案可分为离线(offline)蒸馏、在线(online)蒸馏、自蒸馏(self-distillation)

离线蒸馏(常用)

在离线蒸馏中,学生模型仅使用知识进行训练,而不与教师模型同时更新。学生模型独立地使用知识进行训练,目标是使学生模型的输出尽可能接近教师模型的输出。

大多数之前的知识蒸馏方法都是离线的。最初的知识蒸馏中,知识从预训练的教师模型转移到学生模型中,整个训练过程包括两个阶段:1)大型教师模型蒸馏前在训练样本训练;2)教师模型以logits(基于响应,生成软目标(soft target))或中间特征(基于特征)的形式提取知识,将其在蒸馏过程中指导学生模型的训练。

在线蒸馏

在线蒸馏时,教师模型和学生模型同步更新,而整个知识蒸馏框架都是端到端可训练的。

在线蒸馏是一种具有高效并行计算的单阶段端到端训练方案。然而,现有的在线方法(如相互学习)通常无法解决在线环境中的高容量教师,这使进一步探索在线环境中教师和学生模式之间的关系成为一个有趣的话题。

自蒸馏

在自蒸馏中,教师和学生模型使用相同的网络,这可以看作是在线蒸馏的一个特例。

从人类师生学习的角度可以直观地理解离线、在线和自蒸馏。

  • 离线蒸馏:知识渊博的教师教授学生知识
  • 在线蒸馏:教师和学生一起学习
  • 自我蒸馏:学生自己学习知识

师生架构

  • 教师模型(cumbersome model):已经训练好的,较为笨重的模型。
  • 学生模型:通过蒸馏,将教师模型中已经学习到的知识迁移到的新的轻量级的模型。

2. 学生模型的训练(基于响应的离线知识蒸馏)

hard target(硬目标)与 soft target(软目标)

  • hard target仅包含正样本信息
  • soft target具有更多信息,不仅包含正样本信息,还有相似负样本信息,比如左图的正样本标签为2,但由于写法与3相像,因此对标签3也给予一定的关注通过增大概率值;而右图的正样本标签2写法与7相像,因此对标签7也给予一定的关注。

具体到代码中就是加入蒸馏温度T。

蒸馏温度T

原来的softmax 将多分类的输出结果映射为概率值。q i = e z i ∑ j = 1 n e z j q_i=\frac{e^{z_i}}{\sum_{j=1}^n{e^{z_j}}}qi =∑j=1n ezj ezi ,其中z i z_izi 是模型的softmax层输出logits。

在进行知识蒸馏时,如果将教师模型的softmax输出,作为学生模型的s o f t − t a r g e t soft-targetsoft−target,那么负标签的值接近于0,对学生模型的损失函数贡献非常小,使得模型难以利用教师模型学到的知识。因此,提出蒸馏温度T的概念,使得softmax是输出更加平滑。

加入蒸馏温度T TT后的softmax

q i = e ( z i / T ) ∑ j = 1 n e ( z j / T ) q_i=\frac{e^{(z_i/T)}}{\sum_{j=1}^n{e^{(z_j/T)}}}qi =∑j=1n e(zj /T)e(zi /T)

实验:当温度T TT越高时,负标签的概率值的变化。

正标签为第1个元素,当温度T TT越高时,负标签的概率值相对被放得越大。在训练时,由于损失函数的惩罚,模型需要对负标签给予一定的关注;从而达到在学习老师模型时,一次训练不仅仅可以学到正样本的特征,也可以学到相似负样本的特征。

import numpy as np

def softmax(x):
    x_exp = np.exp(x)
    return x_exp / x_exp.sum()

def softmax_t(x, T):
    # T是蒸馏温度
    x_exp = np.exp(x / T)
    return x_exp / x_exp.sum()

output = np.array([5, 1.3, 2])
print('temperature is 5: ', softmax_t(output, 5))
print('temperature is 10: ', softmax_t(output, 10))
print('temperature is 100: ', softmax_t(output, 100))

知识蒸馏训练的具体步骤

  1. 训练好Teacher模型
  2. 利用高温T h i g h T_{high}Thigh 产生s o f t − t a r g e t soft-targetsoft−target
  3. 使用{s o f t − t a r g e t , T h i g h soft-target, T_{high}soft−target,Thigh }和{h a r d − t a r g e t , T = 1 hard-target, T=1hard−target,T=1},同时训练 Student 模型
  4. 设置蒸馏温度T = 1 T=1T=1,Student模型线上做推理

高温蒸馏过程的损失函数

学生损失函数student loss即,L h a r d = − ∑ j = 1 n l j l o g ( q j ) , q i = e z i ∑ j = 1 n e z j L_{hard}=-\sum_{j=1}^nl_jlog(q_j),q_i=\frac{e^{z_i}}{\sum_{j=1}^n{e^{z_j}}}Lhard =−j=1∑n lj log(qj ),qi =∑j=1n ezj ezi

蒸馏损失函数distillation loss即,L s o f t = − ∑ j = 1 n p j T l o g ( q j T ) , p i T = e ( v i / T ) ∑ j = 1 n e ( v j / T ) , q i T = e ( z i / T ) ∑ j = 1 n e ( z j / T ) L_{soft}=-\sum_{j=1}^np_j^Tlog(q_j^T),p_i^T=\frac{e^{(v_i/T)}}{\sum_{j=1}^n{e^{(v_j/T)}}},q_i^T=\frac{e^{(z_i/T)}}{\sum_{j=1}^n{e^{(z_j/T)}}}Lsoft =−j=1∑n pjT log(qjT ),piT =∑j=1n e(vj /T)e(vi /T) ,qiT =∑j=1n e(zj /T)e(zi /T)

高温蒸馏过程的损失函数定义为:L = α L s o f t + β L h a r d L=\alpha L_{soft}+\beta L_{hard}L=αLsoft +βLhard

其中,l i l_ili 为第i个ground truth值,z i z_izi 为学生模型的第i个输出logits值,v i v_ivi 为老师模型的第i个输出logits值,α \alphaα和β \betaβ为超参数。

pytorch官网知识蒸馏demo

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号