问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

Self-Attention机制的计算详解

创作时间:
作者:
@小白创作中心

Self-Attention机制的计算详解

引用
CSDN
1.
https://blog.csdn.net/qq_41915623/article/details/125161008

Attention的思想

Attention注意力的核心目标就是从众多信息中选择出对当前任务目标更关键的信息,将注意力放在上面。本质思想就是【从大量信息中】【有选择的筛选出】【少量重要信息】并【聚焦到这些重要信息上】,【忽略大多不重要的信息】。聚焦的过程体现在【权重系数】的计算上,权重越大越聚焦于其对应的value值上。即权重代表了信息的重要性,而value是其对应的信息。个人理解,就是对参数进行“加权求和”。

Self-Attention计算公式

其中,X表示输入的数据,Q, K, V对应内容如图,其值都是通过X和超参(先初始化,后通过训练优化)进行矩阵运算得来的。可以理解为:Self-Attention中的Q是对自身(self)输入的变换,而在传统的Attention中,Q来自于外部。

Self-Attention的计算实例

结合代码进行理解:

Step1: 初始化WQ, WK, WV矩阵

class BertSelfAttention(nn.Module):
    self.w_q = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.w_k = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.w_v = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768

假设三种操作的输入都是同等维度的矩阵,这里每个特征维度都是768.即三者的维度:

WQ.shape = [768, 768]
WK.shape = [768, 768]
WV.shape = [768, 768]

Step2:定义输入

输入的特征维度也为768,即:每个字用768维来进行表示,如图所示:

即输入的X的维度为: [6, 768].

Step3:计算Q, K, V

由于维度的问题,需要调换以下顺序,以及可能会涉及到转置:

Q = X ⋅ WQ
K = X ⋅ WK
V = X ⋅ WV

根据以上公式,得到Q, K, V的维度:

Q.shape = [6, 768] * [768,768] = [6, 768]

K,V同理。其维度图如下:

Step4:根据公式计算注意力Attention

Attention(Q, K, V) = softmax(QKTdk)V

First:是Q,K矩阵相乘,维度变化:[6, 768] * [768, 6] = [6, 6],如图:

(1)首先用Q的第一行,即“我”字的768特征和K中“我”字的768为特征点乘求和,得到输出(0,0)位置的数值,这个数值就代表了“我想吃酸菜鱼”中“我”字对“我”字的注意力权重

(2)然后显而易见输出的第一行就是“我”字对“我想吃酸菜鱼”里面每个字的注意力权重;整个结果自然就是“我想吃酸菜鱼”里面每个字对其它字(包括自己)的注意力权重(就是一个数值)了.

Second:除以dk,dkdk表示特征维度,在本例中dk=768dk=768dk=768。之所以要除以这个数,是为了矩阵点乘后的范围,确保softmax的梯度稳定性。

Three:最后就是注意力权重和V矩阵相乘,如图所示:

(1)首先是“我”这个字对“我想吃酸菜鱼”这句话里面每个字的注意力权重,和V中“我想吃酸菜鱼”里面每个字的第一维特征进行相乘再求和,这个过程其实就相当于用每个字的权重对每个字的特征进行加权求和,

(2)然后再用“我”这个字对对“我想吃酸菜鱼”这句话里面每个字的注意力权重和V中“我想吃酸菜鱼”里面每个字的第二维特征进行相乘再求和,依次类推最终也就得到了(L,768)的结果矩阵,和输入保持一致。

注意:注意力机制是没有位置信息的,所以需要引入位置编码。

引申

介绍transformer相关的内部结构。

4.1 Multi-Head Attention

即多头注意力机制,其作用是让模型从多个子空间中关注到不同方面的信息。其模型图如下:

其计算步骤如下

Step1:初始化多组WQ,WK,WV

如上图中初始化了三组,分别是:

W1Q,W1K,W1V
W2Q,W2K,W2V
W3Q,W3K,W3V

Step2:分别计算每组的Q,K,V

得到三组Q, K, V

Step3:分别按照之前的公式计算Attention

由图将计算到的Attention称之为:Z1,Z2,Z3。

Step4:将各组Attention(Z1,Z2,Z3)拼接然后进行线性变换映射到原来的空间中

如下图:

将三个Z1,Z2,Z3进行凭借,然后经过线性变换,即得到和原来输入同等维度的Z,也可以理解为最终的注意力。

4.2 Add & Norm

Add操作的目的是:借鉴了残差网络,防止退化

Norm操作的目的是:对向量进行标准化,以达到加速收敛的效果

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号