Transformer位置编码(Position Embedding)详解
创作时间:
作者:
@小白创作中心
Transformer位置编码(Position Embedding)详解
引用
CSDN
1.
https://blog.csdn.net/weixin_43135178/article/details/136853202
本文主要介绍Transformer中的位置编码(Position Embedding)概念及其在不同场景下的应用。通过对比分析NLP、Vision Transformer、Swin Transformer和Masked AutoEncoder中的位置编码机制,帮助读者深入理解Transformer模型中的位置编码原理。
NLP Transformer中的位置编码
在NLP领域中,Transformer使用的是1D的绝对位置编码,通过sin+cos函数将每个token编码为一个向量。这种编码方式被称为硬编码。
为什么需要位置编码?
位置编码的主要目的是为序列中的每个单词(token)分配一个唯一的表示。使用单个数字(例如索引值)来表示位置存在以下问题:
- 对于长序列,索引的幅度可能会变大
- 如果将索引值规范化为介于0和1之间,可能会为可变长度序列带来问题,因为它们的规范化方式不同
因此,Transformers将每个单词的位置都映射到一个向量。一个句子的位置编码是一个矩阵,其中矩阵的每一行代表序列中的一个token与其位置信息相加。
位置编码计算公式
1D绝对sin-cos常量位置编码的计算公式如下:
- k:token在输入序列中的位置,0<=k<=L-1
- d: 位置编码嵌入空间的维度
- P(k,j): 位置函数,用于映射输入序列中k处的元素到位置矩阵的(k,j)处
- n:用户定义的标量,由Attention Is All You Need的作者设置为10,000
- i: 用于映射到列索引,0<=i<d/2,单个值i映射到正弦和余弦函数
位置编码计算示例
以n=100和d=4为例,短语"I am a robot"的位置编码计算过程如下:
import torch
def create_1d_absolute_sincos_embedding(n_pos_vec, dim):
assert dim % 2 == 0, "dim must be even"
position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)
omege = torch.arange(dim // 2, dtype=torch.float)
omege /= dim / 2.
omege = 1./(100 ** omege)
out = n_pos_vec[:, None] @ omege[None, :]
emb_sin = torch.sin(out)
emb_cos = torch.cos(out)
position_embedding[:, 0::2] = emb_sin
position_embedding[:, 1::2] = emb_cos
return position_embedding
if __name__ == "__main__":
n_pos = 4
dim = 4
n_pos_vec = torch.arange(n_pos, dtype=torch.float)
position_embedding = create_1d_absolute_sincos_embedding(n_pos_vec, dim)
print(position_embedding.shape)
print(position_embedding)
输出结果为:
torch.Size([4, 4])
tensor([[0.0000, 1.0000, 0.0000, 1.0000],
[0.0175, 0.9998, 0.0349, 0.9994],
[0.0349, 0.9994, 0.0698, 0.9976],
[0.0523, 0.9986, 0.1045, 0.9945]])
这表明每个单词都被映射到了一个4维的张量中,且每个单词对应的张量是不同的。
Vision Transformer中的位置编码
Vision Transformer使用的是1D的绝对位置编码,但这个位置编码是可训练的(可学习的)。这种软编码方式虽然增加了模型的参数量,但能够更好地适应图像数据的特性。
代码实现
import torch
import torch.nn as nn
def create_1d_absolute_trainable_embedding(n_pos_vec, dim):
position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
nn.init.constant_(position_embedding.weight, 0)
return position_embedding
if __name__ == "__main__":
n_pos = 3
dim = 4
n_pos_vec = torch.arange(n_pos, dtype=torch.float)
position_embedding = create_1d_absolute_trainable_embedding(n_pos_vec, dim)
print(position_embedding)
Swin Transformer和Masked AutoEncoder中的位置编码
Swin Transformer和Masked AutoEncoder中的位置编码机制较为复杂,具体实现细节可以参考相关论文和代码实现。
参考资料
热门推荐
世界最长寿老人去世,117岁,称“远离有害的人”能活得更久
“春化补肥”很关键,兰花做好这一步,花剑蹭蹭长,花儿开不停!
掌握这10个家庭亲子游戏,有望提升孩子的专注力
16个促进运动和智力发育的亲子游戏!让宝宝快乐成长
法院判决离婚后户籍会体现吗?复婚和再婚需要哪些证件?
香港银行卡的用处:为您的国际业务提供便利与安全
国企求职指南:从招聘信息到面试准备的全方位攻略
美国人开始自己养鸡 鸡蛋价格飙升应对短缺
云南坝美深度游:完整旅行攻略与日程安排指南
代驾服务的基本条件与要求,确保安全出行的重要性与未来发展趋势
考研国家线大降:心态调整、原因分析与备考策略调整
中国大学100强最新出炉:首都北京独占28所,江苏16所,广东8所
房屋办产权需要夫妻签字吗
Steam在线峰值大比拼:非DEI游戏领跑,玩家偏好悄然变化?
差额征税:差额开票还是全额开票?
无糖可乐糖尿病能喝吗
大学绩点计算规则详解
读《资治通鉴》笔记之一百九
避免孩子得抑郁症,家长要怎么做?
破除病耻感是青少年抑郁防治的第一步
8岁孩子的身高体重标准,你了解吗?
直接串连联想法:构建知识链的妙法
酒驾和醉驾的处罚分别是什么
解密上财考研复试,关键要素与制胜策略
课题申报:脱颖而出的策略与技巧
SSID是WiFi名称吗?路由器SSID广播怎么打开?
八方宾客踏歌来!云南文山丘北县壮族“三月三”祭竜节盛大开幕
清气化痰丸与清肺化痰丸的区别
金鱼养护全指南:如何在养金鱼时避免水上产生泡泡?
2025医疗康养空间健康照明研讨会:专家共话“医养之光”