问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

Transformer模型详解:从架构到维度变换

创作时间:
作者:
@小白创作中心

Transformer模型详解:从架构到维度变换

引用
51CTO
1.
https://blog.51cto.com/u_16163453/12995450

自2017年Google在《Attention is All You Need》中提出Transformer模型以来,Transformer已经成为自然语言处理(NLP)领域的核心架构之一。本文将详细介绍Transformer的架构、维度变换的过程等,并结合翻译例子,更好地理解Transformer的工作原理。


Transformer模型架构图

在Transformer出现之前,序列到序列(Seq2Seq)模型主要依赖于循环神经网络(RNN)和长短期记忆网络(LSTM)。然而,RNN和LSTM在处理长序列时存在梯度消失和计算效率低下的问题。Transformer模型的出现,彻底改变了这一局面。Transformer模型完全摒弃了RNN的结构,采用了一种全新的架构——自注意力机制,使得模型能够并行处理序列数据,并且在捕捉长距离依赖关系方面表现出色。自注意力是Transformer的非常重要的创新点,所以开篇先介绍一下注意力,随后我们再按照输入、编码器、解码器、输出的结构来介绍Transformer。

1. 注意力

介绍自注意力之前,先介绍一下注意力。

在处理序列数据(如文本、语音或时间序列)时,传统的循环神经网络(RNN)或卷积神经网络(CNN)倾向于使用固定长度的向量表示整个输入序列。这种方法可能会导致信息丢失,特别是当输入序列很长时。但是注意力机制通过学习一组权重,可以灵活地关注输入序列中与当前任务最相关的部分,从而有效缓解信息压缩的问题。

在计算输出时,模型通过“注意力分数”来评估输入序列中每个元素的重要性,并生成一个加权的上下文向量。输入序列的每个元素通常会被嵌入到一个固定维度的向量空间中表示为X=[x1,x2,…,xn],然后通过以下步骤计算注意力:

Step 1:计算注意力分数

每个输入 xi 与 q 计算出一个分数 si:

si = score(q, xi)

q是什么?q就是一个矩阵,下面解释(包括下面的v、Wk都是矩阵)

通常使用以下几种函数来计算分数:

  • 点积 (Dot Product): si = q · xi
  • 加性函数 (Additive): si = vTtanh(Wqq + Wkxi)
  • 缩放点积 (Scaled Dot Product):

其中,dk是K向量的维度大小

Step 2:计算注意力权重

将分数 si 转化为概率分布,通常使用 softmax 函数:

这里,αi表示输入 xi 的注意力权重,

Step 3: 加权求和

根据权重 αi加权输入 xi 得到上下文向量(Context Vector)c:

例子:英文到中文翻译。假设我们要将以下英文句子翻译成中文:

输入序列:

[I, am, eating, an, apple]

输出序列:

[我, 正在, 吃, 一个, 苹果]

当模型生成目标序列中的某个词时,比如“苹果”,它需要知道输入句子中哪个词最相关。

注意力分配示例:

在生成 “苹果” 时,注意力机制计算输入序列中每个词的相关性,分配权重(假设为分数α):

输入词 权重α

I 0.1

am 0.05

eating 0.15

an 0.2

apple 0.5

权重最高的是“apple”,因为它与生成词 “苹果” 的语义最相关。

根据这些权重,模型将更关注“apple”,而忽略其他词,从而生成正确的翻译。

1.1 自注意力机制

自注意力机制本质上是广义注意力 Q=K=V 时的情形。

计算注意力分数公式:

在这里先不解释公式,往下慢慢看看。

假设我们有一个输入序列X=[x1,x2,…,xn],其中 xi 是序列中的第 i 个元素的向量表示。将输入序列 X 通过三个不同的线性变换,得到三个矩阵:query矩阵 (Q)、key矩阵(K) 和value矩阵(V)。

其中,W是可学习权重矩阵。

单纯地看上面的公式就太抽象了,结合图举个例子来说明如何生成Q,K,V。下图中X=[x1,x2],其中,x1代表了"Thinking"的向量表示,x2代表了"Machines"的向量。q1是x1和WQ做矩阵乘法得到的,同样,x1和WK相乘得到k1,x1和WV相乘得到v1。

其中, W是可学习的权重矩阵。

讲到现在了,什么是“query向量”、“key向量”和“value向量”呢?  它们是用来计算注意力的。怎么计算注意力的,继续往下看。

有了q,k,v之后,就可以来计算注意力得分了。假设我们要计算“Thinking”的自注意力得分,通过对句子中的其他单词的打分,来决定在编码某个位置的单词时,对其他单词的关注程度。这些得分是通过计算query向量与各个单词的key向量的点积得到的。

计算“Thinking”和第一个位置单词(Thinking)的分数,那就是用 q1 和k1 进行点积;计算“Thinking”和第二个单词(Machines)的分数则是 q1 和 k2点积。

然后将得分除以8,为什么除以8呢?因为8是key向量的维数(64)的平方根。为什么key向量维数是64呢?是因为论文中给的是64,我们也可以设别的维数,在这里我就使用论文中设的64。为什么要除以key向量的维数的平方根?因为这能让梯度更新的过程更加稳定。然后将结果进行softmax操作。Softmax将分数归一化,使它们都是正数,加起来等于1。

其中,dk 是key向量的维度。

接下来,将每个value 向量与Softmax计算的结果相乘。这一步的核心思想是:保留希望关注的单词的value 向量值,同时削弱那些不相关单词的影响(不相关的词注意力分数很小,value 乘以一个很小的小数值,就会变得很小)。

最后一步是将这些经过加权计算的value 向量相加,从而得到自注意力机制在该位置上的输出。

在实操中,计算是以矩阵形式进行的,这是为了加快计算速度。W矩阵没有变,就是将x组合在一起了,注意x组合的维度,是在哪一维组合的,拼接出来的X,每一行代表一个单词。

1.2 多头注意力

多头注意力是自注意力机制的扩展,它允许模型在不同的子空间中学习不同的注意力表示,增强了模型关注不同位置的能力。具体来说,多头注意力机制将输入序列通过多个不同的线性变换,得到多个query、key矩阵,然后在每个子空间中计算自注意力,最后将所有子空间的输出拼接起来,并通过一个线性变换得到最终的输出。

示例:多头注意力的应用——句子相似度任务

输入:

句子 1:

"I love programming."

句子 2:

"Coding is my passion."

我们希望通过注意力机制判断两个句子是否相似。

多头注意力过程:

嵌入表示: 将句子 1 和句子 2 分别嵌入为向量:X1,X2

计算多头注意力:

头 1 可能关注单词间的语义相似性(如 “love” 和 “passion”)。

头 2 可能关注句子结构相似性。

将Q=X1,K=X2,V=X2 代入多头注意力机制。

每个注意力头会捕获句子间不同部分的相似性:

输出上下文向量: 每个注意力头生成一个上下文向量,最终拼接这些向量并映射回原始维度,得到句子 1 和句子 2 的对齐表示。

头编号 注意力分布 (句子 1 对句子 2)

头 1 love ↔ passion(高权重)

头 2 programming ↔ coding(高权重)

头 3 I ↔ is(较低权重,但句法结构相关)

每个头关注句子的不同部分,综合后判断两个句子的语义高度相关。

上面为什么 Q=X1, K=X2, V=X2?

在多头注意力机制中,我们计算查询(Query)、键(Key)和值(Value)之间的关系。在进行序列对比时,我们有两个输入序列 X1和X2,并希望通过注意力机制计算它们的相关性。

查询(Query) Q:是用来搜索信息的。这里,我们使用 X1作为查询,因为我们希望从序列 1 中的每个元素(例如,词)出发,去寻找在序列 2 中的相关信息。

键(Key)k:是用来和查询计算匹配程度的。在多头注意力中,键是表示其他序列(这里是X2)的语义内容,查询(来自X1)会与键(来自 X2)进行匹配,从而计算相关性。

值(Value)V:是最终返回的实际信息。在计算完查询和键之间的相似度后,我们基于这个相似度(权重)来加权值。在这里,我们也使用 X2作为值,因为最终我们想要的上下文信息来源于序列 2。

2. 词嵌入

词嵌入是将输入序列中的每个词(或子词)映射到一个高维向量空间的过程。例如,在中英翻译任务中,输入的中文句子“Thinking Machines”可以被表示为:

Thinking→[0.2, 0.5, -0.3, ···]

Machines→[0.1, 0.8, 0.4,···]

我→[0.2,0.5,−0.3,… ]

这些向量捕捉了词与词之间的语义关系,使得模型能够更好地理解输入序列。

3. 位置编码

由于Transformer模型完全基于注意力机制,没有像RNN那样的顺序处理能力,因此需要显式地为输入序列添加位置信息。位置编码通过正弦和余弦函数生成,例如:

其中,pos 是位置索引,i 是维度索引,d 是嵌入维度。

位置编码与词嵌入相加,得到输入序列的最终表示:

输入向量=词嵌入+位置编码码

4. 编码器

编码器的作用是将输入序列(如中文句子)转换为一系列高维向量表示。它由多个相同的层堆叠而成,每一层包含两个主要子层:多头注意力层(Self-Attention),前馈神经网络层(Feed Forward)。每一层之后还包含残差连接和层归一化,以增强模型的稳定性和训练效率。

5. 解码器

解码器的作用是将编码器的输出转换为最终的输出序列(如英文句子)。与编码器类似,解码器也由多个相同的层堆叠而成,但每一层包含三个主要子层:掩码多头注意力(Self-Attention):用于防止解码器在生成当前词时“看到”未来的词。多头注意力(Encoder-Decoder Attention)):结合编码器的输出。前馈神经网络Feed-Forward。解码器的每一层同样包含残差连接和层归一化。

在完成编码过程之后,每一个时间步,解码器会输出翻译后的一个单词。

  1. 线性层和Softmax层

线性层是一个全连接神经网络层,其作用是将解码器输出的高维向量映射到一个更大的维度空间中。这个维度通常等于目标语言的词汇表大小。例如,在中英翻译任务中,如果目标语言的词汇表包含10,000个单词,那么线性层的输出维度就是10,000。

Softmax层紧跟在线性层之后,其作用是将线性层的输出转换为概率分布。具体来说,Softmax函数会对线性层的输出进行归一化处理,使得每个值的范围在0到1之间,并且所有值的和为1。这样,每个值可以被解释为目标语言中某个单词的概率。

7. 整个模型的维度变换

我将详细讲解 Transformer 中每一步的维度变换,考虑批量大小(batch size)序列长度(seq_len)

1)输入

假设我们有一个批量大小为
batch_size
,序列长度为
seq_len
,每个词的嵌入维度是
d

  • 输入的嵌入层会将词索引转换为向量表示,得到维度为:
  • batch_size 是输入批次的大小。
  • seq_len 是每个序列的长度。
  • d 是嵌入维度(每个词的向量表示的维度)。
    之后,通常会加上位置编码,维度不变,仍为
    [batch_size,seq_len, d]

2)编码器

  • 多头自注意力
    假设输入到编码器的维度是
    [batch_size, seq_len, d]
    ,通过线性变换得到查询(Query)、键(Key)和值(Value):
    接下来,进行缩放点积注意力,计算注意力权重并加权值 V,每个头的输出维度是[batch_size,seq_len,dh],其中dh=d/h,是每个注意力头的维度。
    对于每个注意力头hi计算:
    注意力输出拼接后的维度为:
    Multi-Head Attention Output∈Rbatch_size×seq_len×d
    (拼接所有头的输出,维度为batch_size×seq_len×d)
  • 前馈神经网络
    接下来,经过一个前馈神经网络(通常包含两层线性层和一个激活函数,例如 ReLU):
    第一层:输入维度X∈Rbatch_size×seq_len×d Output1 ∈ Rbatch_size×seq_len×d2
    其中, d2 是前馈层的中间维度。
    第二层: Output2 ∈ Rbatch_size×seq_len×d
    输出维度与输入维度相同。
  • 残差连接与层归一化
    每个子层的输出都会加上输入(残差连接),然后进行层归一化(LayerNorm):
    所以每个编码器层的输出维度仍然是:
    Output Encoder∈ Rbatch_size×seq_len×d

3)解码器

解码器的结构与编码器类似,区别在于它还包括 编码器-解码器交叉注意力(Encoder-Decoder Attention),即解码器的查询来自解码器的输入,而键和值来自编码器的输出。

  • 自注意力
    与编码器一样,解码器首先使用自注意力机制,输入维度为:
    Xdec ∈ Rbatch_size×seq_len×d
    经过注意力机制后,输出维度仍然是:
    编码器-解码器交叉注意力
    解码器使用来自编码器的输出作为键和值:
    前馈神经网络
    解码器的前馈神经网络与编码器相似,经过两层线性变换后,输出维度是:
    残差连接与层归一化
    解码器的每个子层也有残差连接和层归一化,最终解码器的输出维度仍为:

4)输出层

解码器的输出通过线性层映射到词汇表的大小V(即词汇表的维度),用于生成概率分布:
这里,V是词汇表的大小,n 是序列长度,最终的输出是每个位置上各个词的概率分布。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号