Skip to content
传衡博客
返回

FlashAttention2 原理与数值推导

参考资料
  1. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  2. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  3. Dao-AILab/flash-attention
  4. Online Softmax by Hand
  5. How Flash Attention Works

FlashAttention2 把注意力机制的显存复杂度从 O(n²) 降到 O(n),在 A100 GPU 上比 v1 快约 2 倍[1]。关键在于两个优化:把大矩阵切块分块计算(tiling),和在线更新 softmax 的分母和最大值(online softmax)。用 4×4 的矩阵算一遍,就能看懂每一步中间状态的更新。

标准 Attention 的显存瓶颈

标准 attention 的计算流程是:先算完整注意力矩阵 S=QKTS = QK^T,存起来,再算 softmax,最后乘以 VV。对于一个序列长度为 NN、hidden 维度为 dd 的 attention 层[2]

N=4096N = 4096 时,SS 矩阵就要占 2×4096233.62 \times 4096^2 \approx 33.6 MB。一个 Transformer block 有多个 attention 层,加上中间激活值,总显存很快爆炸。

核心问题是:整个 SS 矩阵必须一次性算完、存起来,后面才能用。但计算 SS 的每个元素其实只需要 QQ 的某一行和 KK 的某一列,完全可以边算边丢弃中间结果。

GPU 的内存层级

GPU 的存储不是一层的,而是分了几级:

层级容量带宽延迟作用
HBM80 GB (A100)~2 TB/s大容量存储
SRAM20 MB/SM~19 TB/s快速缓存

HBM 是”大而慢”,SRAM 是”小而快”。标准 attention 把中间结果都写回 HBM,频繁的数据搬运成了瓶颈。FlashAttention2 的核心思想是把数据搬运降到最少:把 QQKKVV 切成小块,每次只把需要的一块搬到 SRAM,算完再写回 HBM。

Tiling:分块计算

假设 Q,K,VQ, K, V 都是 N×dN \times d 的矩阵,把它们切成 Bc×BrB_c \times B_r 的小块:

计算时,对于每个 query block QiQ_i,依次加载每个 key block KjK_j 和 value block VjV_j,在 SRAM 里计算局部 attention:

OiOi+softmax(QiKjT/d)VjO_i \leftarrow O_i + \text{softmax}(Q_i K_j^T / \sqrt{d}) V_j

累加完所有 Kj,VjK_j, V_j 后,OiO_i 就是完整的输出。

关键在于:每个 KjK_j 只加载一次,算完就丢,不会占用 O(n²) 显存。SRAM 只需要同时存 QiQ_iKjK_jVjV_j 和中间结果,容量就够了。

Online Softmax:递推更新

标准 softmax 先算完整分数,再找最大值和求和分母:

softmax(x)i=exp(xi)jexp(xj)\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

但在线场景下数据是一批批来的,没法事先知道最大值和分母。FlashAttention2 用递推公式更新:

假设已经处理了前 tt 个元素,保存了:

来了第 t+1t+1 个元素 xt+1x_{t+1},更新公式:

mt+1=max(mt,xt+1)lt+1=ltexp(mtmt+1)+exp(xt+1mt+1)\begin{aligned} m_{t+1} &= \max(m_t, x_{t+1}) \\ l_{t+1} &= l_t \cdot \exp(m_t - m_{t+1}) + \exp(x_{t+1} - m_{t+1}) \end{aligned}

输出时,第 ii 个元素的 softmax 值是:

softmax(x)i=exp(ximT)lT\text{softmax}(x)_i = \frac{\exp(x_i - m_T)}{l_T}

其中 TT 是总元素数。

直观理解:新元素来时,先看它是不是新的最大值。如果是,之前所有的 exp\exp 都要重新缩放;如果不是,只需要把新元素的贡献加到分母上。

Online Softmax 数值演示

用输入 x=[2,5,3,8]x = [2, 5, 3, 8] 手动算一遍,验证公式 的正确性。

标准 Softmax 验证

先算标准 softmax 作为对照:

ixix_iexp(xi)\exp(x_i)softmax
027.3890.0023
15148.4130.0471
2320.0860.0064
382980.9580.9442

分母 l=3156.846l = 3156.846,最终 softmax =[0.0023,0.0471,0.0064,0.9442]= [0.0023, 0.0471, 0.0064, 0.9442]

Online Softmax 逐步计算

初始化:m=m = -\infty, l=0l = 0

Stepxxmm (max)Δ=moldmnew\Delta = m_{old} - m_{new}lexp(Δ)l \cdot \exp(\Delta)exp(xm)\exp(x - m)lnewl_{new}
122-011
2552 - 5 = -31e3=0.04981 \cdot e^{-3} = 0.049811.0498
33501.0498e0=1.04981.0498 \cdot e^0 = 1.0498e2=0.1353e^{-2} = 0.13531.1851
4885 - 8 = -31.1851e3=0.05901.1851 \cdot e^{-3} = 0.059011.0590

最终状态:m=8m = 8, l=1.0590l = 1.0590

验证正确性

用最终状态算每个元素的 softmax:

ixix_iexp(xim)\exp(x_i - m)softmax =exp(xim)/l= \exp(x_i - m) / l
02e6=0.0025e^{-6} = 0.00250.0025/1.0590=0.00240.0025 / 1.0590 = 0.0024
15e3=0.0498e^{-3} = 0.04980.0498/1.0590=0.04700.0498 / 1.0590 = 0.0470
23e5=0.0067e^{-5} = 0.00670.0067/1.0590=0.00640.0067 / 1.0590 = 0.0064
38e0=1e^{0} = 11/1.0590=0.94431 / 1.0590 = 0.9443

和标准 softmax 结果一致(误差来自四舍五入)。

Matrix Tiling 完整示例

用 4×4 的矩阵演示完整的 FlashAttention2 计算流程。

输入矩阵

Q=[12345678],K=[56789101112],V=[910111213141516]Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \\ 7 & 8 \end{bmatrix}, \quad K = \begin{bmatrix} 5 & 6 \\ 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix}, \quad V = \begin{bmatrix} 9 & 10 \\ 11 & 12 \\ 13 & 14 \\ 15 & 16 \end{bmatrix}

d=2d = 2N=4N = 4,分块大小 Br=2B_r = 2, Bc=2B_c = 2

分块策略

计算 O0O_0(query block 0)

初始化:O0=0O_0 = \mathbf{0}, m=m = -\infty, l=0l = \mathbf{0}

第 1 轮:加载 K0,V0K_0, V_0

K0T=[5768],V0=[9101112]K_0^T = \begin{bmatrix} 5 & 7 \\ 6 & 8 \end{bmatrix}, \quad V_0 = \begin{bmatrix} 9 & 10 \\ 11 & 12 \end{bmatrix}

计算 S0=Q0K0TS_0 = Q_0 K_0^T

S0=[1234][5768]=[15+2617+2835+4637+48]=[17233953]S_0 = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 5 & 7 \\ 6 & 8 \end{bmatrix} = \begin{bmatrix} 1\cdot5 + 2\cdot6 & 1\cdot7 + 2\cdot8 \\ 3\cdot5 + 4\cdot6 & 3\cdot7 + 4\cdot8 \end{bmatrix} = \begin{bmatrix} 17 & 23 \\ 39 & 53 \end{bmatrix}

Online softmax + 累加:

iS0[i]S_{0}[i]mmllO0[i]=softmaxVO_0[i] = \sum \text{softmax} \cdot V
01717e0=1e^0 = 11[9,10]=[9,10]1 \cdot [9, 10] = [9, 10]
123231e6+e0=0.0025+1=1.00251 \cdot e^{-6} + e^0 = 0.0025 + 1 = 1.0025e61.0025[9,10]+11.0025[11,12]=0.0025[9,10]+0.9975[11,12]=[10.995,11.995]\begin{aligned} &\frac{e^{-6}}{1.0025} \cdot [9, 10] + \frac{1}{1.0025} \cdot [11, 12] \\ &= 0.0025 \cdot [9, 10] + 0.9975 \cdot [11, 12] \\ &= [10.995, 11.995] \end{aligned}

第 1 轮后:m=[17,23]m = [17, 23], l=[1,1.0025]l = [1, 1.0025], O0=[[9,10],[10.995,11.995]]O_0 = [[9, 10], [10.995, 11.995]]

第 2 轮:加载 K1,V1K_1, V_1

K1T=[9111012],V1=[13141516]K_1^T = \begin{bmatrix} 9 & 11 \\ 10 & 12 \end{bmatrix}, \quad V_1 = \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}

计算 S1=Q0K1TS_1 = Q_0 K_1^T

S1=[1234][9111012]=[29356781]S_1 = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 9 & 11 \\ 10 & 12 \end{bmatrix} = \begin{bmatrix} 29 & 35 \\ 67 & 81 \end{bmatrix}

Online softmax + 累加:

iS1[i]S_1[i]moldm_{old}mnewm_{new}Δ=moldmnew\Delta = m_{old} - m_{new}lnew=loldeΔ+eS1mnewl_{new} = l_{old} \cdot e^\Delta + e^{S_1 - m_{new}}O0O_0 更新
0291729-121e12+e0=6.1×106+111 \cdot e^{-12} + e^0 = 6.1 \times 10^{-6} + 1 \approx 1O0[0]loldeΔlnewO0[0]+eS1mnewlnewV1[0]6.1×106[9,10]+1[13,14][13,14]\begin{aligned} &O_0[0] \leftarrow \frac{l_{old} \cdot e^\Delta}{l_{new}} \cdot O_0[0] + \frac{e^{S_1 - m_{new}}}{l_{new}} \cdot V_1[0] \\ &\approx 6.1 \times 10^{-6} \cdot [9, 10] + 1 \cdot [13, 14] \\ &\approx [13, 14] \end{aligned}
1812381-581.0025e58+e011.0025 \cdot e^{-58} + e^0 \approx 1O0[1]loldeΔlnewO0[1]+eS1mnewlnewV1[1]1.00255.0×1026[10.995,11.995]+1[15,16][15,16]\begin{aligned} &O_0[1] \leftarrow \frac{l_{old} \cdot e^\Delta}{l_{new}} \cdot O_0[1] + \frac{e^{S_1 - m_{new}}}{l_{new}} \cdot V_1[1] \\ &\approx 1.0025 \cdot 5.0 \times 10^{-26} \cdot [10.995, 11.995] + 1 \cdot [15, 16] \\ &\approx [15, 16] \end{aligned}

最终 O0=[13141516]O_0 = \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}

计算 O1O_1(query block 1)

初始化:O1=0O_1 = \mathbf{0}, m=m = -\infty, l=0l = \mathbf{0}

第 1 轮:加载 K0,V0K_0, V_0

Q1=[5678]Q_1 = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}

计算 S0=Q1K0TS_0 = Q_1 K_0^T

S0=[5678][5768]=[618383113]S_0 = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \begin{bmatrix} 5 & 7 \\ 6 & 8 \end{bmatrix} = \begin{bmatrix} 61 & 83 \\ 83 & 113 \end{bmatrix}

Online softmax + 累加:

iS0[i]S_{0}[i]mmllO1[i]O_1[i]
061611[9,10][9, 10]
18383e22+11e^{-22} + 1 \approx 1[11,12][11, 12]

第 2 轮:加载 K1,V1K_1, V_1

计算 S1=Q1K1TS_1 = Q_1 K_1^T

S1=[5678][9111012]=[105127143173]S_1 = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \begin{bmatrix} 9 & 11 \\ 10 & 12 \end{bmatrix} = \begin{bmatrix} 105 & 127 \\ 143 & 173 \end{bmatrix}

Online softmax + 累加:

iS1[i]S_1[i]moldm_{old}mnewm_{new}lnewl_{new}O1[i]O_1[i]
010561105e44+11e^{-44} + 1 \approx 1[13,14][13, 14]
117383173e90+11e^{-90} + 1 \approx 1[15,16][15, 16]

最终 O1=[13141516]O_1 = \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}

最终输出

O=[O0O1]=[1314151613141516]O = \begin{bmatrix} O_0 \\ O_1 \end{bmatrix} = \begin{bmatrix} 13 & 14 \\ 15 & 16 \\ 13 & 14 \\ 15 & 16 \end{bmatrix}

这个结果看似”简单”是因为我们的数值设计让后面的 SS 值远大于前面的,所以前面的贡献被新来的大值”掩盖”了。在实际场景中,QQKKVV 的值分布更均匀,每个块的贡献都会被正确累加。

FlashAttention2 vs v1 性能对比

FlashAttention2 在 v1 基础上做了进一步优化:

特性FlashAttention v1FlashAttention2
并行策略GEMM + softmax算子融合 + 线程块并行
工作划分单线程块多线程块并行处理不同 query block
A100 (4096 长度)1.0x 基线~2x 加速
显存占用O(n)O(n)
数值精度BF16/FP32BF16/FP32

FlashAttention2 的核心改进是工作划分:把不同的 query block 分给不同的线程块并行处理,充分利用 GPU 的 SM 数量[1]。v1 只有一个线程块在串行处理所有 query block,资源利用率低。

代码示例:Hugging Face 集成

现代大模型框架已经内置 FlashAttention2 支持:

# 方法 1:自动启用(推荐)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# 方法 2:手动安装 flash-attn 包
# pip install flash-attn

from flash_attn import flash_attn_func
import torch

q = torch.randn(1, 128, 4096, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 128, 4096, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 128, 4096, 128, dtype=torch.bfloat16, device="cuda")

# 调用 FlashAttention2
output = flash_attn_func(
    q, k, v,
    causal=True,  # 因果掩码(用于 decoder)
    softmax_scale=1.0 / (128 ** 0.5)
)

关键参数:

Dao-AILab/flash-attention 提供了完整的 CUDA 实现,支持 FlashAttention2[3]

FlashAttention2 的局限

FlashAttention2 不是银弹。它在某些场景下可能不合适:

  1. 短序列:当 N<1024N < 1024 时,分块开销可能大于收益
  2. 非标准注意力:稀疏 attention、局部 attention 需要定制实现
  3. KV Cache 压缩:推理时如果用 KV Cache 量化,需要额外适配

但绝大多数 Transformer 模型的训练和推理,FlashAttention2 都能提供显著的加速和显存节省。



Previous Post
LoRA 与 QLoRA:从低秩适配到双重量化
Next Post
【九】在线RL:GRPO 与 DAPO 的推导与代码实现