问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

知识蒸馏原理解读

创作时间:
作者:
@小白创作中心

知识蒸馏原理解读

引用
1
来源
1.
https://www.cnblogs.com/keanshi/p/18698024

知识蒸馏是深度学习领域的重要技术,它通过将大模型的知识转移到小模型中,实现了模型的压缩和加速。本文从热力学中的蒸馏概念出发,详细介绍了知识蒸馏的基本原理、理论基础和具体方法,帮助读者深入理解这一技术的核心思想和实现方式。

背景

首先我们先了解一下蒸馏的定义:蒸馏(distillation)是一种热力学的分离工艺,它利用混合液体或液-固体系中各组分沸点不同,使低沸点组分蒸发,再冷凝以分离整个组分的单元操作过程,是蒸发和冷凝两种单元操作的联合。[1]

从其定义我们可以得知,蒸馏的核心是“分离”;在神经网络中,“知识”可以抽象为模型蕴含的参数信息,因此 知识蒸馏就是将模型的部分参数(知识)分离出来提供给其他的模型 。这一技术的理论来自于 2015 年 Hinton 等人发表的一篇论文:Distilling the Knowledge in a Neural Network。

理论基础

Teacher 和 Student

一般认为,大模型往往比小模型能学习更多、更复杂的知识。但在落地应用时,大模型存在着部署困难、推理效率低、资源耗费多等问题,因此 Hinton 等人想出将“训练”和“部署”解耦:训练时使用参数量多的“Teacher”(大模型),部署时使用速度快的“Student”模型(小模型),即 Teacher -- Student 结构

  • Teacher 是“知识”的输出者,记为 $Ne{t}_{T}$,不受模型参数、架构限制,但需要被充分预训练,任务与 Student 匹配;

  • Student 是“知识”的接受者,记为 $Ne{t}_{S}$,需要采用架构相对简单、参数量较小的模型。

Soft Target

在知识蒸馏时,我们做的就是让 $Ne{t}_{S}$ 学习泛化能力比较强的 $Ne{t}_{T}$。一个很直白且高效的迁移泛化能力的方法就是: 使用 soft target

可以知道 soft target 分布的熵更高,含有的信息也就越多。

Softmax

传统的 softmax 函数如下:

${q}_{i}=\frac{\mathrm{exp}\left({z}_{i}\right)}{\sum _{j}\mathrm{exp}\left({z}_{j}\right)}$

但是如果当 softmax 之后的分布熵相对较小时,负标签的值都很接近 0,对损失函数的贡献非常小,小到可以忽略不计,此时 温度 $T$ 就派上了用场。

下面的公式时加了温度这个变量之后的 softmax 函数:

${q}_{i}=\frac{\mathrm{exp}\left({z}_{i}/T\right)}{\sum _{j}\mathrm{exp}\left({z}_{j}/T\right)}$

$T$ 越高,softmax 的输出分布越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。

方法

图片

知识蒸馏需要用到一个预训练的教师模型以及一个结构简单的学生模型,最终的损失需要用到老师对学生的指导 ${L}_{soft}$(distillation loss)以及学生参考标准答案的误差 ${L}_{hard}$(student loss)。即:

$\begin{array}{rl}& L=\alpha {L}_{soft}+\beta {L}_{hard}\\ & {L}_{soft}=-\sum _{j}^{N}{p}_{j}^{T}\mathrm{log}\left({q}_{j}^{T}\right)\\ & {L}_{hard}=-\sum _{j}^{N}{c}_{j}\mathrm{log}\left({q}_{j}^{1}\right)\end{array}$

其中,${p}_{i}^{T}=\frac{\mathrm{exp}\left({v}_{i}/T\right)}{\sum _{k}^{N}\mathrm{exp}\left({v}_{k}/T\right)}$,${q}_{i}^{T}=\frac{\mathrm{exp}\left({z}_{i}/T\right)}{\sum _{k}^{N}\mathrm{exp}\left({z}_{k}/T\right)}$;${c}_{j}$ 为真实标签,${v}_{i}$ 为 $Ne{t}_{T}$ 的 logits,${z}_{i}$ 为 $Ne{t}_{S}$ 的 logits。

这个训练过程比较通俗易懂的解释就是:老师的各方面知识水平远在学生以上,他会挑选重点进行讲授(由温度控制);但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。

对于温度的选择,一般可以考虑:

  1. 从有部分信息量的负标签中学习 --> 温度要高一些

  2. 防止受负标签中噪声的影响 --> 温度要低一些

  3. $Ne{t}_{S}$ 参数量比较小的时候,相对比较低的温度就可以了

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号