自注意力机制&多头自注意力机制:技术背景、原理分析及基于Pytorch的代码实现
自注意力机制&多头自注意力机制:技术背景、原理分析及基于Pytorch的代码实现
自注意力机制和多头自注意力机制的提出源于谷歌的Vaswani 等人于2017年发表的著名论文 《Attention Is All You Need》,作为Transformer架构的核心技术,其被广泛应用于自然语言处理(NLP)和计算机视觉(CV)等领域,为后来的如BERT、GPT等许多先进的模型奠定了基础。
本文将从技术背景、技术特点、技术原理、维度分析、重点解析及代码实现方面详细介绍自注意力机制和多头自注意力机制。
技术背景
在自然语言处理任务中,自注意力机制和多头自注意力机制提出之前的循环神经网络(RNN)和长短期记忆网络(LSTM)虽然也能够处理序列数据,但它们的计算存在序列依赖性(即必须按照顺序处理),导致训练和推理较慢且难以捕捉长距离依赖关系,二者的提出就是为了克服这一问题。
技术特点
(1)自注意力机制是一种将单个序列内不同位置相关联的注意力机制,以便计算序列的表示,即对序列当中所有位置之间的相互关系建模,使得每个位置的表示不仅依赖于该位置的信息,还能考虑到序列中其它位置的信息,从而能够快速捕捉长距离依赖关系。
(2)多头自注意力机制是自注意力机制的拓展,主要是为了解决自注意力机制会过度地将注意力集中于自身的位置的问题,同时将输入序列划分到多个子空间中分别执行自注意力计算有助于增强模型的特征表达能力。
(3)自注意力机制不依赖序列的顺序处理,可以一次性计算整个序列,得以在训练和推理时进行并行计算,从而大大提高了计算效率。
(4)不仅适用于文本序列,也可以应用于图像、音频等不同类型的数据。
技术原理
(1)自注意力机制和多头自注意力机制的原理如下图所示,左图代表自注意力机制(Self-Attention),右图代表多头自注意力机制(Multi-Head Self-Attention)。
(2)自注意力机制的核心原理是首先将输入序列通过三个线性映射层映射到三个独立的特征空间中,从而生成三个特征向量:查询向量Q、键向量K和值向量V,然后先计算查询向量Q和键向量K的转置的点积,再进行缩放(Scale)并经过Softmax激活函数得到自注意力分数Att_score,最后使用自注意力分数Att_score加权值向量V生成最终的自注意力输出。
(3)多头自注意力机制的核心原理是首先将输入序列划分到h个头上,然后每个头单独地执行自注意力机制计算,最后将多个头的输出向量拼接(Concat)在一起再经过一个线性映射层(Linear Layer)得到最终的输出。
维度分析
(1)如上图所示,在多头自注意力中如果**输入样本大小为n×d_model,即每个样本包含n个特征,每个特征的维度为d_model。**将这个样本划分到h个头上,则每个头上的样本大小为n×(d_model/h)。
(2)在自注意力机制中,一般情况下直接设置**“d_k=d_q=d_v=(d_model/h)”当输入样本大小为n×(d_model/h)**时,将其映射到Q、K和V的三个可学习参数矩阵大小都为(d_model/h)×(d_model/h),最后得到的映射结果大小都为n×(d_model/h)。
(3)Q与K^T点积之后的大小为n×n,再经缩放和Softmax不改变大小,得到Att_s core大小为n×n,最后与V加权得到的输出大小为n_(d_model/h),即与输入大小相同。
(4)回到多头自注意力机制,将每个头的输出在最后一个维度上拼接在一起后得到结果大小为n×d_mode****l,即多头自注意力机制的输入样本大小,最后经线性映射到输出不改变大小。
重点解析
(1)Q、K和V从何而来又为何刚好是它们三者?
自注意力机制需要计算输入序列(样本)中每个位置(特征)之间的相互关系,那么每个输入特征都需要在计算过程中扮演三种角色:
查询(Query,Q): 专为当前特征所用,去和其它特征做计算求相似度。
键(Key,K):专门应付其它特征,为其它特征提供特征信息参与相似度计算。
值(Value,V):代表输入特征本身,专门用来与计算得到的注意力分数做加权。
至于为什么不直接复用原始特征?则是因为通过线性变换可以增强模型的学习能力,模型可以从不同角度表示输入特征,使特征的角色分工明确,减少干扰,提高灵活性。
(2)为什么Q与K^T的点积之后要进行缩放?
要进行缩放的根本原因在于向量的维度越大,两者点积之后的方差会变大,导致数据差异性变大。由于Softmax激活函数是通过指数函数来将输入转换为概率分布的,而指数函数对于输入数值的差异比较敏感,所以当入分数差异过大将会出现较大的分数主导归一化过程使其值接近于1,而较小的分数输出值接近于0。由于Softmax函数在值接近0或1时梯度为0,从而会造成模型反向传播时的梯度消失现象,计算公式如下:
[备注:假设向量长度为d_k,且都服从独立同分布,均值为0,方差为1的正态分布,点积之后方差会变成d_k,而缩放就是将点积结果归一化成一个均值为0,方差为1的向量]
(3)自注意力机制为什么会过度地将注意力集中于当前位置?
根据自注意力机制的计算原理,如果Q与K来自于同一个输入特征,那么这两个向量可能具有较高的自相关性,导致在计算点积时具有较大的值,从而使该值在经过Softmax时占据主导地位,产生较大的注意力权重,因此让注意力过度集中于当前位置。
(4)多头自注意力机制中的“多头”有什么优势?
自注意力机制会出现过度地将注意力集中于当前位置的现象,而如果将输入划分到多个子空间中独立地计算自注意力,那么每个头的输出只代表输入特征不同尺度(局部特征)上的特征表示,即不同的头所关注的范围和模式不同,能够有效分散注意力的集中性。
同时,单个头的输出被拼接融合之后能够形成更全面的特征表达,即单个头的局部偏差能被其它头的多样化特征弥补,使整体表现更加鲁棒。此外,多头独立计算也分散了单一头的权重训练压力。
代码实现
(1)自注意力机制代码实现
(2)多头自注意力机制代码实现