大模型解决长文本输入问题
大模型解决长文本输入问题
大模型在处理长文本输入时面临诸多挑战,其中最核心的是注意力机制的计算复杂度问题。本文将深入探讨这一问题的成因,并介绍几种主流的解决方案,帮助读者理解大模型如何突破长文本处理的限制。
前言
近期,Kimi模型因其宣称支持200万token的输入而备受关注。然而,这一说法存在误导性:实际上,Kimi所依赖的底层模型Moonshot-V1仅支持128k token的处理。这一现象引发我们思考:为什么大模型难以处理长上下文输入?又有哪些技术手段可以解决这一问题?
为什么难以增长token?
原因1:Attention机制
大模型的底层架构是Transformer,其核心是Attention机制。Self-Attention处理n个token的复杂度为O(n^2),这意味着序列长度的增加会显著提升计算复杂度。
原因2:位置编码
虽然位置编码(如ROPE)理论上可以处理无限长的token序列,但在实践中仍面临挑战。过长的序列会导致模型训练效果不佳,需要更精细的优化策略。
解决方法1:稀疏Transformer
稀疏Transformer通过减少每个token需要关注的其他token数量,来降低计算复杂度。以下是几种主要的稀疏注意力机制:
Full Self Attention
最普通的Transformer与每个元素都进行Attention计算。例如,在“我爱中国美食”中,“我”会与所有token进行计算得出结果。
时间空间复杂度均为O(n^2)。
Atrous Self Attention
要求每个token之间的注意力是不连续的,它的注意力只与距离为nk的元素的值相关。例如,在“我爱你的兄弟们”中,k=2时,“我”会与“你”、“兄”、“们”这三个token做计算。
时间空间复杂度都变成了O(n log n)。
Local Self Attention
要求每个token只和自己前后距离为k的token做计算。例如,在“我爱你的兄弟们”中,k=2时,“的”会与“爱”、“你”、“兄”、“弟”做计算。
时间空间复杂度都为O(n)。
Stride Sparse Self Attention
将上述两种方式进行了融合,结合了各自的优点。每个token与距离为m的和距离为nk的都进行计算。
时间空间复杂度:O(n log n)。
Fix Sparse Self Attention
对token进行分组后在组内做全注意力计算+对特定位置的元素固定做注意力计算。
同样保证了局部紧密相关和远程稀疏相关特性。
Sparse Softmax
主要对softmax进行了修改,只保留最大的k个类别,剩下的都置为0。公式如下:
稀疏化的关键作用在于缓解 Softmax 过度学习问题。假设目标类别的分数最大(即S_i = max(S)),则原始交叉熵公式可表示为:
进一步可得不等式:
设当前交叉熵值为L,当S_i = S_j时,解得:
为了使损失降至 0.69,最大的 logit 与最小的 logit 之间的差距必须超过 log(n-1)。当 n 较大时,这对分类问题来说是一个不必要的过大间隔。实际上,我们只需要目标类的 logit 略高于非目标类即可,而不必达到 log(n-1) 这么大的差距,因此常规的交叉熵容易因过度学习而导致过拟合。
一个简单的伪代码实现:
import torch
def sparse_softmax(preds, k):
# 稀疏化 Softmax 函数:仅保留每行中最大的 k 个元素用于 Softmax 计算,其余置零。
# 获取每行的前 k 个最大值及其索引
vals, indices = torch.topk(preds, k, dim=1)
sparse = torch.zeros_like(preds)
# 将前 k 个最大值填充到对应位置
sparse.scatter_(1, indices, vals)
exp_sparse = torch.exp(sparse)
sum_exp_sparse = exp_sparse.sum(dim=1, keepdim=True)
output = exp_sparse / (sum_exp_sparse + 1e-10)
return output
解决方法2:位置编码
位置编码(如ROPE)通过引入周期性函数来编码位置信息,使得模型能够处理任意长度的序列。这种方法在实践中已被证明是有效的。
解决方法3:Multi_QueryAttention
多查询注意力机制通过减少查询向量的数量来降低计算复杂度,但这一方法在长序列处理中的效果仍有待进一步研究。
解决方法4:MOE技术
混合专家(MOE)技术通过引入多个专家模型来处理不同类型的输入,从而提高模型的效率和性能。这一方法在长序列处理中展现出潜力。
本质上来说,解决问题的方法还得是Linear Attention,因为n方的代价太大了,得有更好的解决方法来处理Transformer。