问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

大模型解决长文本输入问题

创作时间:
作者:
@小白创作中心

大模型解决长文本输入问题

引用
CSDN
1.
https://m.blog.csdn.net/wlxsp/article/details/143452372

大模型在处理长文本输入时面临诸多挑战,其中最核心的是注意力机制的计算复杂度问题。本文将深入探讨这一问题的成因,并介绍几种主流的解决方案,帮助读者理解大模型如何突破长文本处理的限制。

前言

近期,Kimi模型因其宣称支持200万token的输入而备受关注。然而,这一说法存在误导性:实际上,Kimi所依赖的底层模型Moonshot-V1仅支持128k token的处理。这一现象引发我们思考:为什么大模型难以处理长上下文输入?又有哪些技术手段可以解决这一问题?

为什么难以增长token?

原因1:Attention机制

大模型的底层架构是Transformer,其核心是Attention机制。Self-Attention处理n个token的复杂度为O(n^2),这意味着序列长度的增加会显著提升计算复杂度。

原因2:位置编码

虽然位置编码(如ROPE)理论上可以处理无限长的token序列,但在实践中仍面临挑战。过长的序列会导致模型训练效果不佳,需要更精细的优化策略。

解决方法1:稀疏Transformer

稀疏Transformer通过减少每个token需要关注的其他token数量,来降低计算复杂度。以下是几种主要的稀疏注意力机制:

Full Self Attention

最普通的Transformer与每个元素都进行Attention计算。例如,在“我爱中国美食”中,“我”会与所有token进行计算得出结果。

时间空间复杂度均为O(n^2)。

Atrous Self Attention

要求每个token之间的注意力是不连续的,它的注意力只与距离为nk的元素的值相关。例如,在“我爱你的兄弟们”中,k=2时,“我”会与“你”、“兄”、“们”这三个token做计算。

时间空间复杂度都变成了O(n log n)。

Local Self Attention

要求每个token只和自己前后距离为k的token做计算。例如,在“我爱你的兄弟们”中,k=2时,“的”会与“爱”、“你”、“兄”、“弟”做计算。

时间空间复杂度都为O(n)。

Stride Sparse Self Attention

将上述两种方式进行了融合,结合了各自的优点。每个token与距离为m的和距离为nk的都进行计算。

时间空间复杂度:O(n log n)。

Fix Sparse Self Attention

对token进行分组后在组内做全注意力计算+对特定位置的元素固定做注意力计算。

同样保证了局部紧密相关和远程稀疏相关特性。

Sparse Softmax

主要对softmax进行了修改,只保留最大的k个类别,剩下的都置为0。公式如下:

稀疏化的关键作用在于缓解 Softmax 过度学习问题。假设目标类别的分数最大(即S_i = max(S)),则原始交叉熵公式可表示为:

进一步可得不等式:

设当前交叉熵值为L,当S_i = S_j时,解得:

为了使损失降至 0.69,最大的 logit 与最小的 logit 之间的差距必须超过 log(n-1)。当 n 较大时,这对分类问题来说是一个不必要的过大间隔。实际上,我们只需要目标类的 logit 略高于非目标类即可,而不必达到 log(n-1) 这么大的差距,因此常规的交叉熵容易因过度学习而导致过拟合。

一个简单的伪代码实现:

import torch

def sparse_softmax(preds, k):
    # 稀疏化 Softmax 函数:仅保留每行中最大的 k 个元素用于 Softmax 计算,其余置零。
    
    # 获取每行的前 k 个最大值及其索引
    vals, indices = torch.topk(preds, k, dim=1)
    
    sparse = torch.zeros_like(preds)
    
    # 将前 k 个最大值填充到对应位置
    sparse.scatter_(1, indices, vals)
    exp_sparse = torch.exp(sparse)
    sum_exp_sparse = exp_sparse.sum(dim=1, keepdim=True)
    output = exp_sparse / (sum_exp_sparse + 1e-10)
    
    return output

解决方法2:位置编码

位置编码(如ROPE)通过引入周期性函数来编码位置信息,使得模型能够处理任意长度的序列。这种方法在实践中已被证明是有效的。

解决方法3:Multi_QueryAttention

多查询注意力机制通过减少查询向量的数量来降低计算复杂度,但这一方法在长序列处理中的效果仍有待进一步研究。

解决方法4:MOE技术

混合专家(MOE)技术通过引入多个专家模型来处理不同类型的输入,从而提高模型的效率和性能。这一方法在长序列处理中展现出潜力。

本质上来说,解决问题的方法还得是Linear Attention,因为n方的代价太大了,得有更好的解决方法来处理Transformer。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号