问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

一文读懂Transformer:强大的算法模型详解

创作时间:
作者:
@小白创作中心

一文读懂Transformer:强大的算法模型详解

引用
CSDN
1.
https://m.blog.csdn.net/u011376987/article/details/141760387

Transformer模型是当前自然语言处理领域最核心的技术之一,它通过引入注意力机制,突破了传统序列模型的局限,实现了并行计算,大大提高了训练效率和模型性能。本文将从Transformer模型的基本结构出发,深入探讨其核心组件,包括输入嵌入、位置编码、自注意力机制、多头注意力机制等,并提供完整的PyTorch实现代码。

Transformer模型是一种基于注意力机制的深度学习模型,广泛应用于自然语言处理(NLP)任务,如机器翻译、文本生成和语义理解。它最初由Vaswani等人在2017年的论文《Attention is All You Need》中提出。它突破了传统序列模型(如RNN和LSTM)的局限,能够并行处理序列数据,从而大大提高了训练效率和模型性能。

Transformer模型的基本结构

Transformer模型由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。编码器将输入序列编码为一个固定长度的上下文向量,解码器则根据这个上下文向量生成输出序列。编码器和解码器各由多个层(Layer)堆叠而成。

编码器(Encoder)

编码器的主要作用是将输入序列转换为一组上下文向量,供解码器使用。每个编码器层包括两个主要的子层:

  • 多头自注意力机制(Multi-Head Self-Attention):捕捉输入序列中不同位置之间的依赖关系。通过不同的注意力头(Attention Heads),模型可以从多个不同的角度来看待输入序列。
  • 前馈神经网络(Feed-Forward Neural Network, FFN):对经过注意力机制处理的序列进行进一步的非线性变换。

每个子层之后都会使用残差连接(Residual Connection)和层归一化(Layer Normalization),这有助于避免梯度消失问题并加快训练收敛速度。

解码器(Decoder)

与编码器类似,解码器也由多个层组成,每个解码器层包含三个子层:

  • 掩码多头自注意力机制:与编码器中的多头自注意力机制类似,但在解码器中,解码器的多头自注意力机制是掩蔽(Masked)的,防止在预测下一个单词时看到未来的信息。
  • 编码器-解码器多头注意力机制(Encoder-Decoder Attention):该注意力机制允许解码器访问编码器的输出,这样解码器就可以根据编码器生成的上下文向量来生成输出序列。
  • 前馈神经网络(Feed-Forward Neural Network, FFN):对经过注意力机制处理的序列进行进一步的非线性变换。

每个子层同样有残差连接和层归一化。

核心组件

下面,我们来详细描述一下Transformer中的核心组件。

输入嵌入

输入嵌入是将输入文本序列中的单词或符号映射为高维向量的过程。在Transformer模型中,文本首先被标记化为单词或子词,然后每个标记被映射为一个固定长度的向量。这些向量通常是通过查找嵌入矩阵(embedding matrix)得到的,该矩阵是在训练过程中学习得到的。输入嵌入的作用是将离散的符号转换为连续的、可以直接输入到神经网络中的向量表示,使得模型能够处理和理解输入数据。

位置编码

位置编码(Positional Encoding)是Transformer模型中的一个关键组件,用于在模型中引入序列位置信息。由于Transformer模型不使用传统的循环神经网络(RNN)结构,它无法像这些传统模型那样通过其结构直接捕获输入数据的位置信息。因此,需要通过位置编码来显式地提供序列中的位置信息。位置编码通常使用正弦和余弦函数来生成。

对于位置pos和嵌入维度中的第i个维度:

其中:

  • pos是位置索引。
  • i是维度索引。
  • d_model是嵌入向量的维度。

自注意力机制

自注意力机制是Transformer的核心创新之一。它允许模型在计算某个位置的输出时,考虑输入序列中所有其他位置的信息。具体地,对于每个输入位置,自注意力机制会计算该位置与其他所有位置的相似度(通过点积操作),并使用这些相似度作为权重来加权求和其他位置的输入表示。

自注意力机制的关键步骤包括:

  • Query、Key、Value向量的生成:对输入嵌入进行线性变换,生成三个不同的向量,即查询向量(Query)、键向量(Key)和值向量(Value)。

  • 每个输入向量x_i,通过三个线性变换分别映射为查询向量q_i、键向量k_i和值向量v_i

其中,q_ik_iv_i是可学习的权重矩阵。

  • 注意力得分的计算:通过点积计算查询向量与所有键向量之间的相似度,得到注意力得分矩阵。对于每个查询向量q_i,通过点积的方式计算它与所有键向量k_j的相似度,得到注意力分数。为了稳定训练过程,这些分数会除以sqrt(d_k),其中d_k是键向量的维度。

  • 加权求和:使用Softmax函数将注意力得分转换为权重,然后对所有值向量进行加权求和,得到最终的输出表示。

多头注意力机制

多头注意力机制是对自注意力机制的扩展。通过并行地执行多次自注意力机制,可以让模型从不同的角度(即不同的“头”)学习输入序列中的信息。每个头都有自己独立的查询、键和值的线性变换,然后分别执行自注意力操作,最后将这些头的输出进行拼接,并通过线性变换生成最终的多头注意力输出。

具体来说,假设有h个注意力头,每个头分别计算如下:

其中,W_i^QW_i^KW_i^V是第i个头的查询、键和值的权重矩阵。

然后,将所有头的输出连接起来,并通过线性变换:

其中,W^O是输出的权重矩阵。

多头注意力机制的优点在于它能够捕捉到不同的语义关系和特征,从而增强模型的表达能力。

前馈神经网络

每个编码器和解码器层中的前馈神经网络是一个两层的全连接神经网络,作用是对每个位置的表示进行独立的非线性变换。公式表示如下:

其中,W_1W_2是权重矩阵,b_1b_2是偏置向量。

层归一化和残差连接

为了防止深层网络的梯度消失问题,Transformer在每个子层后使用了残差连接,并紧跟层归一化。

其中,SubLayer(x)可以是多头注意力机制或前馈神经网络的输出。

掩码多头自注意力

在标准的多头注意力机制中,每个位置的查询(Query)会与所有位置的键(Key)进行点积计算,得到注意力分数,然后与值(Value)加权求和,生成最终的输出。然而,在解码器中,生成序列时不能访问未来的信息。因此需要使用掩码(Mask)机制来屏蔽掉未来位置的信息。

具体来说,在计算注意力得分时,对未来的位置进行屏蔽,将这些位置的得分设为负无穷大,使得Softmax归一化后的权重为零。

编码器-解码器多头注意力

在解码器中的Multi-head Attention也叫做Encoder-Decoder Attention,它的Query来自解码器的self-attention,而Key、Value则是编码器的输出。

案例代码

下面是一个使用PyTorch实现Transformer模型的简单示例代码。该示例展示了如何构建一个基本的Transformer模型并使用它进行序列到序列的任务,例如机器翻译。

import torch
import torch.nn as nn
import torch.optim as optim
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

class TransformerModel(nn.Module):
    def __init__(self, input_dim, output_dim, d_model=512, nhead=8, num_encoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.embedding = nn.Embedding(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, output_dim)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# Example usage:
input_dim = 1000  # Vocabulary size
output_dim = 1000  # Output size
seq_length = 10  # Length of the sequence

# Create the model
model = TransformerModel(input_dim=input_dim, output_dim=output_dim)

# Example data
src = torch.randint(0, input_dim, (seq_length, 32))  # (sequence_length, batch_size)
src_mask = generate_square_subsequent_mask(seq_length)

# Forward pass
output = model(src, src_mask)
print(output.shape)  # Expected output: [sequence_length, batch_size, output_dim]

# Define a simple loss and optimizer for training
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example training loop
for epoch in range(10):  # Number of epochs
    optimizer.zero_grad()
    output = model(src, src_mask)
    loss = criterion(output.view(-1, output_dim), src.view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号