正弦、余弦三角函数位置编码详解与代码实现
正弦、余弦三角函数位置编码详解与代码实现
在Transformer模型中,位置编码(Positional Encoding)是一个关键组件,用于为模型引入序列中元素的位置信息。本文将详细讲解基于正弦和余弦函数的位置编码机制,并提供完整的Python代码实现。
一、正弦、余弦三角函数位置编码讲解
在Transformer中,位置编码是为了引入位置信息,而位置编码的形式通常是一个正弦函数和一个余弦函数的组合,公式如下:
$$
PE(pos, 2i) = sin\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
$$
$$
PE(pos, 2i+1) = cos\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
$$
其中,$PE(pos, i)$ 表示位置编码矩阵中第 $pos$ 个位置,第 $i$ 个维度的值;$d_{model}$ 表示模型嵌入向量的维度;$i$ 表示位置编码矩阵中第 $i$ 个维度的值。这种位置编码方式可以引入位置信息,使得Transformer模型可以处理序列数据。
假设序列长度为4,位置编码维度为6,则位置编码矩阵如下:
其中三角函数括号中的部分可以由*号拆分成两部分,第一部分可以理解为$x$,第二部分可以理解为周期(普通的三角函数$sin(2\pi X)$的周期$T$为$2\pi$,$X$为因变量)。
按列分析:如dim0这一列周期$T$为
$X$为0~3的一个周期为定值的三角函数;
按行分析:
如pos0这一行中,周期每两个元素变化一次,$X$为递增数列;所以按行看每个pos的位置编码是一个变周期($T$)的三角函数;
二、代码实现
1. 实现位置编码矩阵
import torch
def creat_pe_absolute_sincos_embedding(n_pos_vec, dim):
assert dim % 2 == 0, "wrong dim"
position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)
omega = torch.arange(dim//2, dtype=torch.float)
omega /= dim/2.
omega = 1./(10000**omega)
sita = n_pos_vec[:,None] @ omega[None,:]
emb_sin = torch.sin(sita)
emb_cos = torch.cos(sita)
position_embedding[:,0::2] = emb_sin
position_embedding[:,1::2] = emb_cos
return position_embedding
2. 初始化序列长度和位置编码的维度,并计算位置编码矩阵
n_pos = 512
dim = 768
n_pos_vec = torch.arange(n_pos, dtype=torch.float)
pe = creat_pe_absolute_sincos_embedding(n_pos_vec, dim)
print(pe)
输出结果为:
tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,
0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 8.2843e-01, ..., 1.0000e+00,
1.0243e-04, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 9.2799e-01, ..., 1.0000e+00,
2.0486e-04, 1.0000e+00],
...,
[ 6.1950e-02, 9.9808e-01, 5.3552e-01, ..., 9.9857e-01,
5.2112e-02, 9.9864e-01],
[ 8.7333e-01, 4.8714e-01, 9.9957e-01, ..., 9.9857e-01,
5.2214e-02, 9.9864e-01],
[ 8.8177e-01, -4.7168e-01, 5.8417e-01, ..., 9.9856e-01,
5.2317e-02, 9.9863e-01]])
3. 按行对位置编码矩阵进行可视化
import matplotlib.pyplot as plt
x = [i for i in range(dim)]
for index, item in enumerate(pe):
if index % 50 != 1:
continue
y = item.tolist()
plt.plot(x, y, label=f"数据 {index}")
plt.show()
以50为间隔打印,由于序列长度为512,所以可以打印出11个pos位置的曲线,下图为pos0,pos250,pos500处的位置编码曲线:
4. 按列对位置编码矩阵进行可视化
x = [i for i in range(n_pos)]
for index, item in enumerate(pe.transpose(0, 1)):
if index % 50 != 1:
continue
y = item.tolist()
plt.plot(x, y, label=f"数据 {index}")
plt.show()
以50为间隔打印,由于序列长度为768,所以可以打印出16个pos位置的曲线,下图为dim0,dim350,dim750处的位置编码曲线: