不是RNN的锅!清华团队深入分析长上下文建模中的状态崩溃,Mamba作者点赞
不是RNN的锅!清华团队深入分析长上下文建模中的状态崩溃,Mamba作者点赞
【新智元导读】RNN模型在长上下文中的表现一直备受关注。近日,来自清华大学的研究团队对这一问题进行了深入的实验分析,揭示了RNN模型在长上下文中面临的主要挑战,并提出了有效的解决方案。这项研究得到了Mamba作者Albert Gu的高度评价。
与Transformer相比,RNN模型的一大优势是应对长序列的能力。例如,Mamba模型的内部状态大小始终保持不变,计算随序列长度线性增长,这使得它在处理长序列时具有较高的效率。
然而,实际情况是,目前的这些RNN模型在长上下文中的有效性并不能令人满意。为了解决这一问题,来自清华大学的研究团队进行了深入的实验研究。
研究发现:两个主要问题
研究团队发现,Mamba这类RNN模型在长上下文中主要面临两个问题:
- 无法推断比训练长度更长的输入:这是由于较短的训练数据导致了循环状态过拟合。
- 内存容量的上限:由于模型无法有效遗忘很久以前的信息,导致新的信息存不进来了。
解决方案
针对上述问题,研究团队提出了三种解决方案:
Method 1: Forget More and Remember Less
通过增加状态衰减量(忘记更多)或减少输入信息的数量(记住更少)来减少SC,作者选择干预Bt和αt(分别控制输入强度和内存衰减强度)。Method 2: State Normalization
在每次更新后对状态进行归一化,以确保状态的范数始终低于阈值。这种方法会将模型转换为非线性RNN,无法以与原始模型相同的方式并行化,预填充速度要慢得多。Method 3: Sliding Window by State Difference
利用状态ht可以写为加权和的形式,来模拟滑动窗口机制,无需在每一步都从窗口的开头重新处理。此方法适用于所有可以写成加权和的RNN,包括RWKV 5和6、RetNet、GLA等。尽管会使生成的计算和内存成本翻倍,但仍然是一个可以接受的权衡,因为RNN的生成成本比Transformer低很多。
实验结果
研究团队通过实验验证了这些解决方案的有效性。实验数据选择了RedPajama-V2数据集,这是一个从CommonCrawl中提取的30T token的开放数据集,进行了去重以确保数据质量。
实验结果表明,所有提出的方法都成功地抑制了状态崩溃(SC),使模型能够泛化到超过64K个token。其中,状态归一化方法在较短序列上的性能明显低于其他方法,这可能是因为归一化折叠状态会改变heads之间的规范比率,破坏了学习机制。
Mamba作者点赞
这项研究得到了Mamba作者Albert Gu的高度评价。他认为这是一篇非常棒的论文,揭示了状态空间模型(SSM)的状态容量和长上下文能力的重要见解。特别是研究中发现的临界值K与状态大小M呈线性关系,表明每个token可能存在某种固有的信息含量,这一发现具有重要的理论意义。
Albert Gu还指出,过分担心循环模型的长度泛化问题可能是一个误区。实际上,我们不需要设计新机制或特殊的缓解措施,只需要在更长的序列上训练模型,就能获得更好的泛化效果。
结论
这项研究不仅揭示了RNN模型在长上下文建模中面临的具体问题,还提出了有效的解决方案,为未来的研究和应用提供了重要的参考。正如Albert Gu所说,要让你的Mamba吃得饱饱的,它就能发挥出最佳状态!
论文地址:https://arxiv.org/pdf/2410.07145v1
参考资料: