Transformer中的自注意力机制和多头自注意力机制详解
Transformer中的自注意力机制和多头自注意力机制详解
Transformer中的自注意力机制和多头自注意力机制是其核心组成部分,用于建模序列中不同位置之间的关系,从而捕获全局上下文信息。
自注意力机制
自注意力机制允许模型在处理输入序列的每一个位置时,动态地关注序列中其他位置的信息。这种机制使得模型能够捕捉到序列中各个部分之间的依赖关系,无论这些依赖关系是局部的还是全局的。
原理
自注意力是一种计算序列中每个位置与序列中其他所有位置之间相关性的机制,用于生成该位置的上下文表示。
对于输入序列:
- 线性变换
每个输入向量 通过三个线性变换得到查询(Query)、键(Key)、值(Value)矩阵
其中 是可训练的权重矩阵。
- 计算注意力分数
通过矩阵 Q 和矩阵 K 的点积计算注意力分数
这里的 是一个缩放因子
- 计算注意力权重
对注意力分数应用 Softmax 函数,得到每个位置的注意力权重
- 加权求和
用上述注意力权重对值 V 进行加权求和,得到每个位置的上下文表示。
特点与优势
- 并行计算
自注意力机制可以并行处理序列中的所有位置,显著提高计算效率。
- 长距离依赖
能够有效捕捉序列中任意两个位置之间的依赖关系,解决了传统 RNN 在处理长序列时的梯度消失问题。
- 动态权重
注意力权重是动态计算的,能够根据输入内容自适应地调整关注的重点。
多头自注意力机制
多头自注意力是对自注意力的扩展,通过引入多个“头”(head),使模型能够在不同的子空间中并行地进行自注意力计算,从而捕捉更多样化的特征和关系。
原理
- 多个子空间投影
输入数据 X 通过多个独立的线性变换生成不同的查询、键、值矩阵
其中 是注意力头的数量, 是每个头独立的可训练的权重矩阵。
- 独立计算自注意力
每个头独立计算自注意力
- 拼接与线性变换
将 H 个头的输出拼接后,进行最终的线性变换,得到多头自注意力的输出。
其中 是最终线性变换的权重矩阵。
优点
- 捕获多样化特征
每个头可以学习到不同的语义模式或依赖关系,一些头可能更关注局部上下文,一些头可能更关注全局依赖。
- 增强模型能力
多个头并行处理,增加模型表达能力。
- 提升稳定性
通过分散注意力到多个子空间,降低单头注意力的偏差。
示例代码
以下是一个实现多头自注意力的示例代码,使用了 PyTorch 框架。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadSelfAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(MultiHeadSelfAttention, self).__init__()
assert embed_size % num_heads == 0, "Embedding size must be divisible by the number of heads"
self.embed_size = embed_size
self.num_heads = num_heads
self.head_dim = embed_size // num_heads
# Linear layers for query, key, and value
self.query = nn.Linear(embed_size, embed_size)
self.key = nn.Linear(embed_size, embed_size)
self.value = nn.Linear(embed_size, embed_size)
# Final linear layer to combine heads
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, x):
batch_size, seq_length, embed_size = x.shape
# Project inputs to query, key, value spaces
Q = self.query(x) # (batch_size, seq_length, embed_size)
K = self.key(x) # (batch_size, seq_length, embed_size)
V = self.value(x) # (batch_size, seq_length, embed_size)
# Reshape for multi-head: (batch_size, seq_length, num_heads, head_dim)
Q = Q.view(batch_size, seq_length, self.num_heads, self.head_dim)
K = K.view(batch_size, seq_length, self.num_heads, self.head_dim)
V = V.view(batch_size, seq_length, self.num_heads, self.head_dim)
# Transpose for attention calculation: (batch_size, num_heads, seq_length, head_dim)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Scaled Dot-Product Attention
energy = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
attention = F.softmax(energy, dim=-1) # (batch_size, num_heads, seq_length, seq_length)
# Apply attention weights to values
out = torch.matmul(attention, V) # (batch_size, num_heads, seq_length, head_dim)
# Concatenate heads: (batch_size, seq_length, embed_size)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_size)
# Final linear transformation
out = self.fc_out(out) # (batch_size, seq_length, embed_size)
return out
# Example usage
if __name__ == "__main__":
embed_size = 64
num_heads = 8
seq_length = 10
batch_size = 32
x = torch.rand(batch_size, seq_length, embed_size) # Random input tensor
multihead_attention = MultiHeadSelfAttention(embed_size, num_heads)
out = multihead_attention(x)
print(out.shape) # Expected: (batch_size, seq_length, embed_size)
代码解析
- 初始化
embed_size
是输入嵌入的维度。num_heads
是多头注意力的头数,则每个头的维度为head_dim = embed_size // num_heads
。
- 线性变换
query
, key
, 和 value
是投影层,将输入投影到查询、键和值空间。
- 多头拆分
将嵌入分解为多个头,用 view
和 transpose
调整维度以适配多头计算。
注意力计算
头的组合
将多个头的输出拼接,并通过线性层 fc_out
整合为最终输出。
- 输出
输出的形状与输入相同,但包含了上下文信息。