Mamba技术背景详解:从RNN到Mamba一文搞定!
Mamba技术背景详解:从RNN到Mamba一文搞定!
Mamba技术背景详解:从RNN到Mamba
本文将详细介绍Mamba技术的背景和原理。从RNN和Transformer的优缺点入手,引出SSM(状态空间模型)的概念,并详细解释了Mamba如何通过选择性扫描算法和硬件感知算法来优化模型性能。
Transformer到Mamba
Transformer缺陷:
一次性矩阵每个token进行比较(支持并行化)
推理缺陷:生成下一个token任务中,要算所有token的注意力(L^2)
RNN解决:
RNN只考虑之前隐藏状态和当前输入,防止重新计算所有先前状态
但RNN会遗忘信息(不然就不会有Transformer出现了)
RNN是顺序循环——>训练不能并行
其实这也不能说是谁解决谁缺陷吧,毕竟lstm和transformer的出现就是为了解决RNN的遗忘的总之,RNN推理速度快,但不能并行,Transformer反之。
❓能否以某种方式找到一种像 Transformer 这样并行训练的架构,同时仍然执行随序列长度线性扩展的推理?
SSM(State Space Model)
State Space:
SSM:预测下一个状态
- 输入序列x(t) —(在迷宫中向左和向下移动)
- 潜在状态h(t) —(距离和 x/y 坐标)
- 预测输出序列y(t) —(再次移动以更快到达出口)
然而,它不使用离散序列(如向左移动一次),而是将连续序列作为输入并预测输出序列。
A,B,C——>SSM
D——>skip connection(提供从输入到输出的直接信号)
SSM——>连续
连续信号到离散信号
Zero-order hold technique:
有了连续的输入信号,我们可以生成连续的输出,并且仅根据输入的时间步长对值进行采样,采样值就是我们的离散输出。
循环表征
类似于RNN:
卷积
图像:过滤器(卷积核)导出聚合特征
文本:一维
三个表征
- 连续
- 循环
- 卷积
训练期间,使用可以并行化的卷积表示;推理期间,使用高效的循环表示
对于每个token,A、B、C都是相同的,不管什么顺序
矩阵A:根据之前状态构建new state(produces hidden size产生一个隐藏状态来记住其历史)
如何保留上下文大小创建A?——Hungry Hippo(根据勒让德多项式Legendre polynomial的系数实现)
Mamba:A Selective SSM
优点:用于循环与卷积。它可以通过构建 HiPPO 矩阵来处理长文本序列并有效存储内存。
序列的结构化状态空间由三部分组成:
- 用于创建Discretization for creating 循环recurrent和 and 卷积convolution表示的离散化 representations
- HiPPO 用于处理HiPPO for handling 远程依赖关系long-range dependencies
- 状态空间模型State Space Models
Mamba
主要贡献:
- 选择性扫描算法selective scan algorithm——允许模型过滤相关信息
- 硬件感知算法hardware-aware algorithm——允许通过并行扫描、内核融合和重新计算有效存储结构
拟解决的问题:在关注或忽略特定输入的能力上表现不佳
两个任务:选择性复制selective copying、感应头induction heads
SSM:
选择性复制selective copying:复制部分输入并按顺序输出
SSM表现不好的原因:由于 SSM 是时不变的,它无法从其历史记录中选择要调用的先前token(矩阵B不变,A、C也不变)
SSM 的循环表示创建了一个非常高效的小状态,因为它压缩了整个历史记录。然而,与不压缩历史记录(通过注意力矩阵)的 Transformer 模型相比,它的功能要弱得多,所以要有选择的保留信息。
有选择地保留信息
曼巴的目标是两全其美。一个像 Transformer 的状态一样强大的小状态:
有选择地将数据压缩到状态中
为了有选择地压缩信息,我们需要参数依赖于输入。为此,我们首先探讨训练期间 SSM 中输入和输出的维度:
在结构化状态空间模型 (S4) 中ND是静态且不会改变
mamba通过合并输入的序列长度和批量大小来使矩阵B和C, 甚至step sizeΔ取决于输入:
对于每个输入标记,有不同的B,C矩阵,可以解决内容感知问题
⚠️矩阵A保持不变,因为我们希望状态本身保持静态,但它受到影响的方式(矩阵B,C)是动态的
他们一起选择将哪些内容保留在隐藏状态以及忽略哪些内容,因为它们现在依赖于输入
较小的step sizeΔ会导致忽略特定单词,而是更多地使用先前的上下文,而较大的step sizeΔ会更多地关注输入单词而不是上下文:
扫描操作 The Scan Operation
由于这些矩阵现在是动态的,因此无法使用卷积表示来计算它们,因为它假定了一个固定内核。我们只能使用循环表示,而失去了卷积提供的并行性。
为了实现并行化,让我们探讨如何使用循环计算输出:
扫描操作:每个状态都是前一个状态(乘A)加上当前输入(乘B的总和)can easily be calculated with a for loop.
相反,并行化似乎是不可能的,因为只有在我们拥有前一个状态的情况下才能计算每个状态。Mamba通过 并行扫描parallel scan算法:
它假设我们执行操作的顺序与关联属性无关。因此,我们可以分段计算序列并迭代地组合它们:
动态矩阵B,C以及扫描算法一起创建,选择性扫描算法selective scan algorithm 来表示使用循环表示的动态和快速本质。
硬件感知算法 Hardware-aware Algorithm
最新 GPU 的一个缺点是其小型但高效的 SRAM 与大型但效率稍低的 DRAM 之间的传输 (IO) 速度有限。在 SRAM 和 DRAM 之间频繁复制信息成为瓶颈。
Mamba 与 Flash Attention 一样,试图限制我们需要从 DRAM 到 SRAM 的次数。
通过kernel fusion实现,允许模型防止写入中间结果并连续执行计算直到完成。
以下内容被融合到一个内核中:
- 步长为Δ 的离散化步长
- 选择性扫描算法
- 与C相乘
硬件感知算法的最后一部分是重新计算——recomputation
中间状态不会被保存,但对于向后传递计算梯度是必需的。相反,mamba在向后传递期间重新计算这些中间状态。尽管这看起来效率低下,但它比从相对较慢的 DRAM 读取所有这些中间状态的成本要低得多。
Mamba block
selective SSM:
selective SSM 可以作为一个块来实现,就像解码器块中表示自注意力一样*。我们可以堆叠多个 Mamba 块并将它们的输出用作下一个 Mamba 块的输入:
它从线性投影开始,以扩展输入嵌入。然后,在selective SSM 之前应用卷积以防止独立计算token。
总而言之,selective SSM 有以下属性:
- Recurrent SSM 被离散化创建
- HiPPO对矩阵A进行初始化以捕获长程依赖性
- 选择性扫描算法选择性压缩信息
- 加速计算的硬件感知算法
Result:
来源:https://maartengrootendorst.substack.com/p/a-visual-guide-to-mamba-and-state