问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

如何具体理解Self Attention中的Q、K、V以及计算过程

创作时间:
作者:
@小白创作中心

如何具体理解Self Attention中的Q、K、V以及计算过程

引用
CSDN
1.
https://blog.csdn.net/xfysq_/article/details/137243237

Self-Attention机制是Transformer模型的核心组成部分,它通过计算输入序列中不同位置之间的关系来捕捉长距离依赖。本文将详细解释Self-Attention中的Q(Query)、K(Key)、V(Value)以及它们的计算过程,并通过代码示例展示具体实现。

一、计算过程理解

  1. 我们直接用PyTorch实现一个Self-Attention:

首先定义三个线性变换矩阵,query, key, value:

class BertSelfAttention(nn.Module):
    self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768

注意,这里的query, key, value只是一种操作(线性变换)的名称,实际的Q/K/V是这三个线性操作的输出,三个变换的输入都是768维,输出都是768维,也就是三个线性变换矩阵的维度都为(768, 768)。

  1. 假设三种操作的输入都是同一个矩阵,这里暂且定为长度为6的句子,每个token的特征维度是768,那么输入就是(6, 768),每一行就是一个字的词向量,像这样:


图1 输入词向量矩阵

乘以上面代码中的三种线性变换操作就得到了Q/K/V三个矩阵,他们的维度为(6, 768) * (768, 768) = (6, 768),维度其实没变,即此刻的Q/K/V分别为:


图2 输入词向量矩阵与线性变换矩阵相乘输出Q、K、V矩阵

代码为:

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(6, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
  1. 计算Self Attention:

Attention(Q,K,V)=Softmax(QKTdk)V

(1) 首先是Q和K矩阵相乘,(6, 768) × (6, 768)T = (6, 6),如图3:


图3 Q和K矩阵相乘的结果

具体的计算过程,首先用Q的第一行,即“我”字的768特征和K中“我”字的768为特征点乘求和,得到输出矩阵(0,0)位置的数值,这个数值就代表了“我想吃酸菜鱼”中“我”字对“我”字的注意力权重。最终输出矩阵的第一行就是“我”字对“我想吃酸菜鱼”里面每个字的注意力权重。整个输出矩阵就是“我想吃酸菜鱼”里面每个字对其它字(包括自己)的注意力权重。

(2) 然后是除以dk,这个dim就是768。

1)至于为什么要除以这个数值?主要是为了缩小点积范围,确保softmax梯度稳定性。

2)为什么要Softmax?主要是为了保证注意力权重的非负性,同时增加非线性。

(3) 然后就是刚才的注意力权重和V矩阵相乘,如图4:


图4 注意力权重和V矩阵相乘

注意力权重 × VALUE矩阵 = 最终结果

首先是“我”这个字对“我想吃酸菜鱼”这句话里面每个字的注意力权重,和V中“我想吃酸菜鱼”里面每个字的第一维特征进行相乘再求和,这个过程其实就相当于用每个字的权重对每个字的特征进行加权求和。然后再用“我”这个字对“我想吃酸菜鱼”这句话里面每个字的注意力权重和V中“我想吃酸菜鱼”里面每个字的第二维特征进行相乘再求和,依次类推最终也就得到了(6, 768)的结果矩阵,和输入保持一致。

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(L, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
        
        attention_scores = torch.matmul(Q, K.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        out = torch.matmul(attention_probs, V)
        return out
  1. 为什么叫自注意力机制?

因为可以看到Q/K/V都是通过同一句话的输入算出来的,按照上面的流程也就是一句话内每个字对其它字(包括自己)的权重分配。

如果不是自注意力的话,Q来自于句A,K,V来自于句B。

  1. 注意,K/V中,如果同时替换任意两个字的位置,对最终的结果是不会有影响的,也就是说注意力机制是没有位置信息的,不像CNN/RNN/LSTM,这也是为什么要引入位置embeding的原因。

二、整体代码

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(L, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
        
        attention_scores = torch.matmul(Q, K.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        out = torch.matmul(attention_probs, V)
        return out
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号