The Cache vs. Quality Trade-off: From MHA to GQA (with Detailed Code Implementation)



In deep learning, the Transformer has attracted wide attention for its strong performance and broad range of applications. The attention mechanism is the core component of the Transformer, and MHA (Multi-Head Attention) and GQA (Grouped-Query Attention) are two important ways to implement it. This article walks through code implementations of both mechanisms, explaining how they work and how they trade off KV-cache size against model quality.

Multi-Head Attention

The attention mechanism in the original Transformer is MHA: the hidden state is split across num_heads heads, so each head's head_dim is 1/num_heads of the original embedding size. Much like grouped convolution, this opens many parallel channels, each attending to different details of the input, so each head can focus on features from a different subspace of the sequence.

import torch
from torch import nn

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # initialize the Q/K/V projection matrices
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        ## output linear layer
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        ## project, then split into heads: (batch, num_heads, seq_len, head_dim)
        query = self.split_head(self.q_linear(hidden_state))
        key = self.split_head(self.k_linear(hidden_state))
        value = self.split_head(self.v_linear(hidden_state))
        ## compute scaled dot-product attention scores
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            ## attention_mask is assumed to be 1 at positions that should be masked out
            attention_scores += attention_mask * -1e9
        ## normalize the attention scores
        attention_probs = torch.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_probs, value)
        ## concatenate the heads back together: (batch, seq_len, hidden_size)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        output = self.o_linear(output)
        return output

    def split_head(self, x):
        batch_size = x.size()[0]
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
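
As a quick sanity check (a usage sketch added here, not part of the referenced post; the dimensions are assumed), the module above can be exercised on a random batch and should return a tensor with the same shape as its input:

batch_size, seq_len, hidden_size, num_heads = 2, 10, 512, 8
mha = MultiHeadAttention(hidden_size, num_heads)
hidden_state = torch.randn(batch_size, seq_len, hidden_size)
output = mha(hidden_state)
print(output.shape)  # torch.Size([2, 10, 512])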

Grouped-Query Attention

GQA splits the query heads into g groups, and all query heads within a group share a single set of K/V heads. Because only g sets of keys and values need to be computed and cached instead of one per head, this design largely preserves model quality while effectively reducing the KV-cache footprint.

import torch.nn.functional as F

class GroupQueryAttention(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, groups=2, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.groups = groups
        self.head_dim = embed_dim // num_heads
        self.group_heads = num_heads // groups  # query heads that share each KV group
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # K/V are projected to only groups * head_dim channels instead of embed_dim
        self.k_proj = nn.Linear(embed_dim, self.groups * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, self.groups * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        batch_size, seq_len, _ = x.size()
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # q: (batch, num_heads, seq_len, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # k: (batch, groups, head_dim, seq_len); v: (batch, groups, seq_len, head_dim)
        k = k.view(batch_size, seq_len, self.groups, self.head_dim).permute(0, 2, 3, 1)
        v = v.view(batch_size, seq_len, self.groups, self.head_dim).transpose(1, 2)
        # broadcast each KV group to the group_heads query heads that share it
        k = k.unsqueeze(2).expand(-1, -1, self.group_heads, -1, -1).contiguous()
        k = k.view(batch_size, self.num_heads, self.head_dim, seq_len)
        v = v.unsqueeze(2).expand(-1, -1, self.group_heads, -1, -1).contiguous()
        v = v.view(batch_size, self.num_heads, seq_len, self.head_dim)
        # scaled dot-product attention
        attn_scores = torch.matmul(q, k)
        attn_scores = attn_scores / (self.head_dim ** 0.5)
        if key_padding_mask is not None:
            # key_padding_mask: (batch, seq_len), True at padded positions
            mask = key_padding_mask.view(batch_size, 1, 1, seq_len)
            attn_scores = attn_scores.masked_fill(mask, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, v)
        # merge the heads back: (batch, seq_len, embed_dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(output)
    
if __name__ == "__main__":
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_heads = 8
    groups = 2
    gqa = GroupQueryAttention(d_model, num_heads, groups)
    x = torch.randn(batch_size, seq_len, d_model)
    output = gqa(x)
    print(output.shape)  # torch.Size([2, 10, 512])

By comparing the MHA and GQA implementations, we can better understand how they trade cache against quality. MHA uses many independent heads to capture features from different subspaces, while GQA shares K/V across grouped query heads, cutting compute and KV-cache requirements while largely preserving quality. Choosing and tuning between these mechanisms is essential for building Transformer models that are both efficient and performant.
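
To make the cache side of this trade-off concrete, here is a rough back-of-the-envelope comparison (a sketch using assumed, illustrative values: 32 layers, 32 attention heads, 8 KV groups, head_dim 128, a 4096-token context, fp16 storage, batch size 1). During autoregressive decoding, MHA must cache K and V for every head, while GQA only caches them for the KV groups, so the cache shrinks by a factor of num_heads / groups:

# Rough KV-cache size comparison (all values assumed for illustration)
num_layers = 32
num_heads = 32
groups = 8
head_dim = 128
seq_len = 4096
bytes_per_elem = 2  # fp16

# K and V are cached for every layer and every cached position
mha_cache = 2 * num_layers * num_heads * head_dim * seq_len * bytes_per_elem
gqa_cache = 2 * num_layers * groups * head_dim * seq_len * bytes_per_elem

print(f"MHA KV cache: {mha_cache / 2**30:.2f} GiB")  # 2.00 GiB
print(f"GQA KV cache: {gqa_cache / 2**30:.2f} GiB")  # 0.50 GiB
print(f"reduction factor: {num_heads // groups}x")   # 4x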
