FlashAttention 原理之 softmax 分块计算
FlashAttention 原理之 softmax 分块计算
FlashAttention是一种优化Transformer模型中注意力机制的算法,其中Softmax分块计算是其核心组成部分之一。本文将详细介绍Softmax函数的基本原理以及如何通过分块计算来优化其性能。
标准 Softmax
Softmax函数(也称为归一化指数函数)是一个将向量转换成概率分布的函数。对于输入向量 x,softmax函数将其转换为一个概率分布向量,其中每个元素的值在 (0,1) 之间,且所有元素之和为 1。
$$
softmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$
归一化 Softmax
定义向量 x 的最大值:
$$
m(x) := \max_i x_i
$$
这里,$x = [x_1, x_2, \dots, x_B]$ 表示一个有 B 个分量的向量(例如,模型输出的对各类的打分)。$m(x)$ 则是向量 x 中所有分量的最大值。
定义指数向量:
$$
f(x) := \bigl[e^{,x_1 - m(x)},, e^{,x_2 - m(x)},,\dots,, e^{,x_B - m(x)}\bigr]
$$
这里做了一个“减最大值”的操作,即把每个 $x_i$ 都减去整个向量的最大分量 $m(x)$,然后取指数。这样做的好处是数值更稳定:当 $x_i$ 很大时,直接算 $e^{x_i}$ 容易导致溢出;但减去最大值以后,指数部分变为 $x_i - m(x)$(一个相对较小的或非正的数),从而避免数值爆炸。
计算指数向量的和:
$$
\ell(x) := \sum_i f(x)_i
$$
这里把向量 $f(x)$ 的各个分量加起来得到标量 $\ell(x)$。
计算最终的 Softmax 输出:
$$
softmax(x) := \frac{f(x)}{\ell(x)}
$$
把每个分量 $f(x)_i$ 除以 $\ell(x)$ 后,就得到了标准的 Softmax 输出向量。
$$
softmax(x)_i = \frac{e^{,x_i - m(x)}}{\sum_j e^{,x_j - m(x)}}
$$
由于每一项都经过指数函数且被总和归一化,它满足所有分量都非负且所有分量之和为 1,因此是一个有效的概率分布。
分块 Softmax
假设我们有两个同维度向量 $x^{(1)}$ 和 $x^{(2)} \in \mathbb{R}^B$,把它们拼接(concatenate)成:
$$
x = \bigl[x^{(1)},,x^{(2)}\bigr] \in \mathbb{R}^{2B}
$$
下面的公式说明,如何在不重复完整计算的情况下,用“各自的部分计算结果”组合成拼接后向量的 Softmax。先给出它的步骤,再解释其意义和好处:
最大值 $m(x)$ 的分块计算
定义单个向量的最大值:
对于 $x^{(1)} \in \mathbb{R}^B$,我们先定义:
$$
m\bigl(x^{(1)}\bigr) ;=; \max_i \Bigl(x^{(1)}_i\Bigr), \quad m\bigl(x^{(2)}\bigr) ;=; \max_i \Bigl(x^{(2)}_i\Bigr)
$$
定义拼接向量的最大值:
由于 $x$ 是把 $x^{(1)}$ 和 $x^{(2)}$ 拼到一起,那么:
$$
m(x) ;=; m\bigl([x^{(1)},, x^{(2)}]\bigr) ;=; \max\bigl(m(x^{(1)}), ;m(x^{(2)})\bigr)
$$
这样就不需要对拼接后的 $x$ 再扫描一次去找最大值,而是只要比较两个子向量各自的最大值即可。
“指数向量” $f(x)$ 的分块计算
Recall:
$$
f(x) ;=; \Bigl[e^{,x_1 - m(x)},; e^{,x_2 - m(x)},;\dots,; e^{,x_{2B} - m(x)}\Bigr]
$$
由于 $x$ 拆成了两块 $x^{(1)}$、$x^{(2)}$,我们分别对每块计算其对应的“指数向量”:
$$
f\bigl(x^{(1)}\bigr) \quad\text{和}\quad f\bigl(x^{(2)}\bigr)
$$
然后拼起来即可。但要记住,每一块真正要减去的“中心化值”是整个 $x$ 的最大值 $m(x)$,因此它们之间会出现一个额外的“补偿系数”:
$$
f(x) ;=; \Bigl[ e^{,m(x^{(1)}) - m(x)},\cdot f\bigl(x^{(1)}\bigr), ;; e^{,m(x^{(2)}) - m(x)},\cdot f\bigl(x^{(2)}\bigr) \Bigr]
$$
直观上看,如果某一块(比如 $x^{(1)}$)的最大元素是整个拼接向量的最大元素,那么它带来的指数因子就会是 $e^{,m(x^{(1)}) - m(x)} = e^0 = 1$。而另一块若不是最大的,就会额外乘上一个小于 1 的因子。
归一化项 $\ell(x)$ 的分块计算
Softmax 要把向量的指数项归一化到和为 1,所以我们需要计算:
$$
\ell(x) ;=; \sum_{i=1}^{2B} e^{,x_i - m(x)}
$$
利用分块思想,可以分成两段求和,再用与上一步相同的补偿系数连接起来:
$$
\ell(x) ;=; \underbrace{e^{,m(x^{(1)}) - m(x)},\ell\bigl(x^{(1)}\bigr)}{\text{第1块贡献}} ;+; \underbrace{e^{,m(x^{(2)}) - m(x)},\ell\bigl(x^{(2)}\bigr)}{\text{第2块贡献}}
$$
同理,也只需要各块自己内部的和,再用一个相对的尺度因子即可。
最大值 $\mathrm{softmax}(x)$ 的分块形式
把上面得到的 $f(x)$ 和 $\ell(x)$ 带入到:
$$
\mathrm{softmax}(x) ;=; \frac{f(x)}{\ell(x)}
$$
就得到在分块后的 Softmax 形式:
$$
\mathrm{softmax}(x) ;=; \frac{ \Bigl[ e^{,m(x^{(1)}) - m(x)} ,f\bigl(x^{(1)}\bigr), ;; e^{,m(x^{(2)}) - m(x)} ,f\bigl(x^{(2)}\bigr) \Bigr] }{ e^{,m(x^{(1)}) - m(x)},\ell\bigl(x^{(1)}\bigr) ;+; e^{,m(x^{(2)}) - m(x)},\ell\bigl(x^{(2)}\bigr) }
$$
为什么这样做?
- 数值稳定性
跟单向量计算 Softmax 类似,这里也要减去整段向量的最大值 $m(x)$,以避免 $e^z$ 里的 $z$ 太大或太小导致溢出/下溢。
- 减少重复计算
如果我们已经知道各块各自的 $\max$ 值和求和 $\ell(x^{(k)})$,那就无需把 $x$ 整体重新扫描、求最大值、求指数和,总结出公式即可快速拼成拼接后向量的软最大值。
- 方便分布式或分块处理
在实际系统里,$x^{(1)}$ 和 $x^{(2)}$ 可能来自不同子网络或不同设备。这种分块计算可以让每一块先在本地完成自己的 Softmax 部分计算,最后再做一次简短的组合归一化即可。