参考资料
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Dao-AILab/flash-attention
- Online Softmax by Hand
- How Flash Attention Works
FlashAttention2 把注意力机制的显存复杂度从 O(n²) 降到 O(n),在 A100 GPU 上比 v1 快约 2 倍[1]。关键在于两个优化:把大矩阵切块分块计算(tiling),和在线更新 softmax 的分母和最大值(online softmax)。用 4×4 的矩阵算一遍,就能看懂每一步中间状态的更新。
标准 Attention 的显存瓶颈
标准 attention 的计算流程是:先算完整注意力矩阵 S=QKT,存起来,再算 softmax,最后乘以 V。对于一个序列长度为 N、hidden 维度为 d 的 attention 层[2]:
- S 矩阵大小:N×N
- 每个元素存 BF16:2 字节
- 总显存:2N2 字节
当 N=4096 时,S 矩阵就要占 2×40962≈33.6 MB。一个 Transformer block 有多个 attention 层,加上中间激活值,总显存很快爆炸。
核心问题是:整个 S 矩阵必须一次性算完、存起来,后面才能用。但计算 S 的每个元素其实只需要 Q 的某一行和 K 的某一列,完全可以边算边丢弃中间结果。
GPU 的内存层级
GPU 的存储不是一层的,而是分了几级:
| 层级 | 容量 | 带宽 | 延迟 | 作用 |
|---|
| HBM | 80 GB (A100) | ~2 TB/s | 高 | 大容量存储 |
| SRAM | 20 MB/SM | ~19 TB/s | 低 | 快速缓存 |
HBM 是”大而慢”,SRAM 是”小而快”。标准 attention 把中间结果都写回 HBM,频繁的数据搬运成了瓶颈。FlashAttention2 的核心思想是把数据搬运降到最少:把 Q、K、V 切成小块,每次只把需要的一块搬到 SRAM,算完再写回 HBM。
Tiling:分块计算
假设 Q,K,V 都是 N×d 的矩阵,把它们切成 Bc×Br 的小块:
- Q 按行分块:每块 Br 行
- K,V 按行分块:每块 Bc 行
计算时,对于每个 query block Qi,依次加载每个 key block Kj 和 value block Vj,在 SRAM 里计算局部 attention:
Oi←Oi+softmax(QiKjT/d)Vj
累加完所有 Kj,Vj 后,Oi 就是完整的输出。
关键在于:每个 Kj 只加载一次,算完就丢,不会占用 O(n²) 显存。SRAM 只需要同时存 Qi、Kj、Vj 和中间结果,容量就够了。
Online Softmax:递推更新
标准 softmax 先算完整分数,再找最大值和求和分母:
softmax(x)i=∑jexp(xj)exp(xi)
但在线场景下数据是一批批来的,没法事先知道最大值和分母。FlashAttention2 用递推公式更新:
假设已经处理了前 t 个元素,保存了:
- mt:当前最大值
- lt:当前分母(exp 的和)
来了第 t+1 个元素 xt+1,更新公式:
mt+1lt+1=max(mt,xt+1)=lt⋅exp(mt−mt+1)+exp(xt+1−mt+1)
输出时,第 i 个元素的 softmax 值是:
softmax(x)i=lTexp(xi−mT)
其中 T 是总元素数。
直观理解:新元素来时,先看它是不是新的最大值。如果是,之前所有的 exp 都要重新缩放;如果不是,只需要把新元素的贡献加到分母上。
Online Softmax 数值演示
用输入 x=[2,5,3,8] 手动算一遍,验证公式 的正确性。
标准 Softmax 验证
先算标准 softmax 作为对照:
| i | xi | exp(xi) | softmax |
|---|
| 0 | 2 | 7.389 | 0.0023 |
| 1 | 5 | 148.413 | 0.0471 |
| 2 | 3 | 20.086 | 0.0064 |
| 3 | 8 | 2980.958 | 0.9442 |
分母 l=3156.846,最终 softmax =[0.0023,0.0471,0.0064,0.9442]。
Online Softmax 逐步计算
初始化:m=−∞, l=0
| Step | x | m (max) | Δ=mold−mnew | l⋅exp(Δ) | exp(x−m) | lnew |
|---|
| 1 | 2 | 2 | - | 0 | 1 | 1 |
| 2 | 5 | 5 | 2 - 5 = -3 | 1⋅e−3=0.0498 | 1 | 1.0498 |
| 3 | 3 | 5 | 0 | 1.0498⋅e0=1.0498 | e−2=0.1353 | 1.1851 |
| 4 | 8 | 8 | 5 - 8 = -3 | 1.1851⋅e−3=0.0590 | 1 | 1.0590 |
最终状态:m=8, l=1.0590
验证正确性
用最终状态算每个元素的 softmax:
| i | xi | exp(xi−m) | softmax =exp(xi−m)/l |
|---|
| 0 | 2 | e−6=0.0025 | 0.0025/1.0590=0.0024 |
| 1 | 5 | e−3=0.0498 | 0.0498/1.0590=0.0470 |
| 2 | 3 | e−5=0.0067 | 0.0067/1.0590=0.0064 |
| 3 | 8 | e0=1 | 1/1.0590=0.9443 |
和标准 softmax 结果一致(误差来自四舍五入)。
Matrix Tiling 完整示例
用 4×4 的矩阵演示完整的 FlashAttention2 计算流程。
输入矩阵
Q=13572468,K=57911681012,V=911131510121416
设 d=2,N=4,分块大小 Br=2, Bc=2。
分块策略
- Q 分成 2 个行块:Q0 (行 0-1), Q1 (行 2-3)
- K 分成 2 个行块:K0 (行 0-1), K1 (行 2-3)
- V 分成 2 个行块:V0 (行 0-1), V1 (行 2-3)
计算 O0(query block 0)
初始化:O0=0, m=−∞, l=0
第 1 轮:加载 K0,V0
K0T=[5678],V0=[9111012]
计算 S0=Q0K0T:
S0=[1324][5678]=[1⋅5+2⋅63⋅5+4⋅61⋅7+2⋅83⋅7+4⋅8]=[17392353]
Online softmax + 累加:
| i | S0[i] | m | l | O0[i]=∑softmax⋅V |
|---|
| 0 | 17 | 17 | e0=1 | 1⋅[9,10]=[9,10] |
| 1 | 23 | 23 | 1⋅e−6+e0=0.0025+1=1.0025 | 1.0025e−6⋅[9,10]+1.00251⋅[11,12]=0.0025⋅[9,10]+0.9975⋅[11,12]=[10.995,11.995] |
第 1 轮后:m=[17,23], l=[1,1.0025], O0=[[9,10],[10.995,11.995]]
第 2 轮:加载 K1,V1
K1T=[9101112],V1=[13151416]
计算 S1=Q0K1T:
S1=[1324][9101112]=[29673581]
Online softmax + 累加:
| i | S1[i] | mold | mnew | Δ=mold−mnew | lnew=lold⋅eΔ+eS1−mnew | O0 更新 |
|---|
| 0 | 29 | 17 | 29 | -12 | 1⋅e−12+e0=6.1×10−6+1≈1 | O0[0]←lnewlold⋅eΔ⋅O0[0]+lneweS1−mnew⋅V1[0]≈6.1×10−6⋅[9,10]+1⋅[13,14]≈[13,14] |
| 1 | 81 | 23 | 81 | -58 | 1.0025⋅e−58+e0≈1 | O0[1]←lnewlold⋅eΔ⋅O0[1]+lneweS1−mnew⋅V1[1]≈1.0025⋅5.0×10−26⋅[10.995,11.995]+1⋅[15,16]≈[15,16] |
最终 O0=[13151416]
计算 O1(query block 1)
初始化:O1=0, m=−∞, l=0
第 1 轮:加载 K0,V0
Q1=[5768]
计算 S0=Q1K0T:
S0=[5768][5678]=[618383113]
Online softmax + 累加:
| i | S0[i] | m | l | O1[i] |
|---|
| 0 | 61 | 61 | 1 | [9,10] |
| 1 | 83 | 83 | e−22+1≈1 | [11,12] |
第 2 轮:加载 K1,V1
计算 S1=Q1K1T:
S1=[5768][9101112]=[105143127173]
Online softmax + 累加:
| i | S1[i] | mold | mnew | lnew | O1[i] |
|---|
| 0 | 105 | 61 | 105 | e−44+1≈1 | [13,14] |
| 1 | 173 | 83 | 173 | e−90+1≈1 | [15,16] |
最终 O1=[13151416]
最终输出
O=[O0O1]=1315131514161416
这个结果看似”简单”是因为我们的数值设计让后面的 S 值远大于前面的,所以前面的贡献被新来的大值”掩盖”了。在实际场景中,Q、K、V 的值分布更均匀,每个块的贡献都会被正确累加。
FlashAttention2 vs v1 性能对比
FlashAttention2 在 v1 基础上做了进一步优化:
| 特性 | FlashAttention v1 | FlashAttention2 |
|---|
| 并行策略 | GEMM + softmax | 算子融合 + 线程块并行 |
| 工作划分 | 单线程块 | 多线程块并行处理不同 query block |
| A100 (4096 长度) | 1.0x 基线 | ~2x 加速 |
| 显存占用 | O(n) | O(n) |
| 数值精度 | BF16/FP32 | BF16/FP32 |
FlashAttention2 的核心改进是工作划分:把不同的 query block 分给不同的线程块并行处理,充分利用 GPU 的 SM 数量[1]。v1 只有一个线程块在串行处理所有 query block,资源利用率低。
代码示例:Hugging Face 集成
现代大模型框架已经内置 FlashAttention2 支持:
# 方法 1:自动启用(推荐)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
)
# 方法 2:手动安装 flash-attn 包
# pip install flash-attn
from flash_attn import flash_attn_func
import torch
q = torch.randn(1, 128, 4096, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 128, 4096, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 128, 4096, 128, dtype=torch.bfloat16, device="cuda")
# 调用 FlashAttention2
output = flash_attn_func(
q, k, v,
causal=True, # 因果掩码(用于 decoder)
softmax_scale=1.0 / (128 ** 0.5)
)
关键参数:
causal=True:启用因果掩码(用于 decoder-only 模型)
softmax_scale:手动指定缩放因子,默认 1/dk
- 不需要显式传
dropout,推理时会自动禁用
Dao-AILab/flash-attention 提供了完整的 CUDA 实现,支持 FlashAttention2[3]。
FlashAttention2 的局限
FlashAttention2 不是银弹。它在某些场景下可能不合适:
- 短序列:当 N<1024 时,分块开销可能大于收益
- 非标准注意力:稀疏 attention、局部 attention 需要定制实现
- KV Cache 压缩:推理时如果用 KV Cache 量化,需要额外适配
但绝大多数 Transformer 模型的训练和推理,FlashAttention2 都能提供显著的加速和显存节省。