问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

Transformer中的自注意力机制和多头自注意力机制详解

创作时间:
作者:
@小白创作中心

Transformer中的自注意力机制和多头自注意力机制详解

引用
CSDN
1.
https://m.blog.csdn.net/z4400840/article/details/144715232

Transformer中的自注意力机制和多头自注意力机制是其核心组成部分,用于建模序列中不同位置之间的关系,从而捕获全局上下文信息。

自注意力机制

自注意力机制允许模型在处理输入序列的每一个位置时,动态地关注序列中其他位置的信息。这种机制使得模型能够捕捉到序列中各个部分之间的依赖关系,无论这些依赖关系是局部的还是全局的。

原理

自注意力是一种计算序列中每个位置与序列中其他所有位置之间相关性的机制,用于生成该位置的上下文表示。

对于输入序列:

  1. 线性变换

每个输入向量 通过三个线性变换得到查询(Query)、键(Key)、值(Value)矩阵

其中 是可训练的权重矩阵。

  1. 计算注意力分数

通过矩阵 Q 和矩阵 K 的点积计算注意力分数

这里的 是一个缩放因子

  1. 计算注意力权重

对注意力分数应用 Softmax 函数,得到每个位置的注意力权重

  1. 加权求和

用上述注意力权重对值 V 进行加权求和,得到每个位置的上下文表示。

特点与优势

  • 并行计算

自注意力机制可以并行处理序列中的所有位置,显著提高计算效率。

  • 长距离依赖

能够有效捕捉序列中任意两个位置之间的依赖关系,解决了传统 RNN 在处理长序列时的梯度消失问题。

  • 动态权重

注意力权重是动态计算的,能够根据输入内容自适应地调整关注的重点。

多头自注意力机制

多头自注意力是对自注意力的扩展,通过引入多个“头”(head),使模型能够在不同的子空间中并行地进行自注意力计算,从而捕捉更多样化的特征和关系。

原理

  1. 多个子空间投影

输入数据 X 通过多个独立的线性变换生成不同的查询、键、值矩阵

其中 是注意力头的数量, 是每个头独立的可训练的权重矩阵。

  1. 独立计算自注意力

每个头独立计算自注意力

  1. 拼接与线性变换

将 H 个头的输出拼接后,进行最终的线性变换,得到多头自注意力的输出。

其中 是最终线性变换的权重矩阵。

优点

  • 捕获多样化特征

每个头可以学习到不同的语义模式或依赖关系,一些头可能更关注局部上下文,一些头可能更关注全局依赖。

  • 增强模型能力

多个头并行处理,增加模型表达能力。

  • 提升稳定性

通过分散注意力到多个子空间,降低单头注意力的偏差。

示例代码

以下是一个实现多头自注意力的示例代码,使用了 PyTorch 框架。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        
        assert embed_size % num_heads == 0, "Embedding size must be divisible by the number of heads"
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        # Linear layers for query, key, and value
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        # Final linear layer to combine heads
        self.fc_out = nn.Linear(embed_size, embed_size)
    
    def forward(self, x):
        batch_size, seq_length, embed_size = x.shape
        # Project inputs to query, key, value spaces
        Q = self.query(x)  # (batch_size, seq_length, embed_size)
        K = self.key(x)    # (batch_size, seq_length, embed_size)
        V = self.value(x)  # (batch_size, seq_length, embed_size)
        # Reshape for multi-head: (batch_size, seq_length, num_heads, head_dim)
        Q = Q.view(batch_size, seq_length, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_length, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_length, self.num_heads, self.head_dim)
        # Transpose for attention calculation: (batch_size, num_heads, seq_length, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        # Scaled Dot-Product Attention
        energy = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        attention = F.softmax(energy, dim=-1)  # (batch_size, num_heads, seq_length, seq_length)
        # Apply attention weights to values
        out = torch.matmul(attention, V)  # (batch_size, num_heads, seq_length, head_dim)
        # Concatenate heads: (batch_size, seq_length, embed_size)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_size)
        # Final linear transformation
        out = self.fc_out(out)  # (batch_size, seq_length, embed_size)
        return out

# Example usage
if __name__ == "__main__":
    embed_size = 64
    num_heads = 8
    seq_length = 10
    batch_size = 32
    x = torch.rand(batch_size, seq_length, embed_size)  # Random input tensor
    multihead_attention = MultiHeadSelfAttention(embed_size, num_heads)
    out = multihead_attention(x)
    print(out.shape)  # Expected: (batch_size, seq_length, embed_size)

代码解析

  1. 初始化
  • embed_size 是输入嵌入的维度。
  • num_heads 是多头注意力的头数,则每个头的维度为 head_dim = embed_size // num_heads
  1. 线性变换

query, key, 和 value 是投影层,将输入投影到查询、键和值空间。

  1. 多头拆分

将嵌入分解为多个头,用 viewtranspose 调整维度以适配多头计算。

  1. 注意力计算

  2. 头的组合

将多个头的输出拼接,并通过线性层 fc_out 整合为最终输出。

  1. 输出

输出的形状与输入相同,但包含了上下文信息。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号