问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

正弦、余弦三角函数位置编码详解与代码实现

创作时间:
作者:
@小白创作中心

正弦、余弦三角函数位置编码详解与代码实现

引用
CSDN
1.
https://blog.csdn.net/Brilliant_liu/article/details/135033645

在Transformer模型中,位置编码(Positional Encoding)是一个关键组件,用于为模型引入序列中元素的位置信息。本文将详细讲解基于正弦和余弦函数的位置编码机制,并提供完整的Python代码实现。

一、正弦、余弦三角函数位置编码讲解

在Transformer中,位置编码是为了引入位置信息,而位置编码的形式通常是一个正弦函数和一个余弦函数的组合,公式如下:

$$
PE(pos, 2i) = sin\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
$$

$$
PE(pos, 2i+1) = cos\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
$$

其中,$PE(pos, i)$ 表示位置编码矩阵中第 $pos$ 个位置,第 $i$ 个维度的值;$d_{model}$ 表示模型嵌入向量的维度;$i$ 表示位置编码矩阵中第 $i$ 个维度的值。这种位置编码方式可以引入位置信息,使得Transformer模型可以处理序列数据。

假设序列长度为4,位置编码维度为6,则位置编码矩阵如下:

其中三角函数括号中的部分可以由*号拆分成两部分,第一部分可以理解为$x$,第二部分可以理解为周期(普通的三角函数$sin(2\pi X)$的周期$T$为$2\pi$,$X$为因变量)。

按列分析:如dim0这一列周期$T$为
$X$为0~3的一个周期为定值的三角函数;

按行分析
如pos0这一行中,周期每两个元素变化一次,$X$为递增数列;所以按行看每个pos的位置编码是一个变周期($T$)的三角函数;

二、代码实现

1. 实现位置编码矩阵

import torch

def creat_pe_absolute_sincos_embedding(n_pos_vec, dim):
    assert dim % 2 == 0, "wrong dim"
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)
    omega = torch.arange(dim//2, dtype=torch.float)
    omega /= dim/2.
    omega = 1./(10000**omega)
    sita = n_pos_vec[:,None] @ omega[None,:]
    emb_sin = torch.sin(sita)
    emb_cos = torch.cos(sita)
    position_embedding[:,0::2] = emb_sin
    position_embedding[:,1::2] = emb_cos
    return position_embedding

2. 初始化序列长度和位置编码的维度,并计算位置编码矩阵

n_pos = 512
dim = 768
n_pos_vec = torch.arange(n_pos, dtype=torch.float)
pe = creat_pe_absolute_sincos_embedding(n_pos_vec, dim)
print(pe)

输出结果为:

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2843e-01,  ...,  1.0000e+00,
          1.0243e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.2799e-01,  ...,  1.0000e+00,
          2.0486e-04,  1.0000e+00],
        ...,
        [ 6.1950e-02,  9.9808e-01,  5.3552e-01,  ...,  9.9857e-01,
          5.2112e-02,  9.9864e-01],
        [ 8.7333e-01,  4.8714e-01,  9.9957e-01,  ...,  9.9857e-01,
          5.2214e-02,  9.9864e-01],
        [ 8.8177e-01, -4.7168e-01,  5.8417e-01,  ...,  9.9856e-01,
          5.2317e-02,  9.9863e-01]])

3. 按行对位置编码矩阵进行可视化

import matplotlib.pyplot as plt

x = [i for i in range(dim)]
for index, item in enumerate(pe):
    if index % 50 != 1:
        continue
    y = item.tolist()
    plt.plot(x, y, label=f"数据 {index}")
    plt.show()

以50为间隔打印,由于序列长度为512,所以可以打印出11个pos位置的曲线,下图为pos0,pos250,pos500处的位置编码曲线:

4. 按列对位置编码矩阵进行可视化

x = [i for i in range(n_pos)]
for index, item in enumerate(pe.transpose(0, 1)):
    if index % 50 != 1:
        continue
    y = item.tolist()
    plt.plot(x, y, label=f"数据 {index}")
    plt.show()

以50为间隔打印,由于序列长度为768,所以可以打印出16个pos位置的曲线,下图为dim0,dim350,dim750处的位置编码曲线:

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号