参考资料
- Attention Is All You Need - Vaswani et al., 2017
- Fast Transformer Decoding: One Write-Head is All You Need - Shazeer et al., 2019
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints - Ainslie et al., 2023
- DeepSeek-V2: Technical Report - DeepSeek-AI, 2024
GQA 通过分组共享 KV,在保持精度的同时将单层 KV Cache 从 32MB 压缩到 4MB;MLA 进一步用低秩投影压缩 KV Cache,DeepSeek-V2 实现了 93.3% 的显存缩减(64X)[4]。从 MHA 到 MLA,四种 Attention 机制代表了 LLM 推理优化的完整演进路径。
标准 Multi-Head Attention (MHA)
MHA 是原始 Transformer 的核心组件[1]。每个头独立计算 Q、K、V 三个投影矩阵,注意力输出是各头的拼接。
标准 Attention 计算公式:
其中 ,, 是头数, 是每头维度。
KV Cache 显存分析
推理时需要缓存每层的 K 和 V 用于自回归生成。单层 KV Cache 显存(BF16,每元素 2 字节):
两个 2 分别表示 K/V 两个矩阵和 BF16 的字节数。
以 Llama-3-8B 单层为例:seq_len=4096, num_heads=32, d_kv=128
小矩阵手动计算
用 4×2 矩阵演示 MHA 的 Attention 计算。设 ,,seq_len=4:
计算 :
计算 (每行独立):
这就是 MHA 的核心计算流程。每个头的 K、V 都需要完整存储,所以显存占用与头数线性相关。
Multi-Query Attention (MQA)
MQA 由 Shazeer 在 2019 年提出[2]。核心思想很简单:所有 query head 共享同一套 KV。
MHA 中,num_kv_heads = num_heads;MQA 中,num_kv_heads = 1。
显存节省
单层 KV Cache 变成:
同样的配置下,MQA 的 KV Cache 是 MHA 的 倍。Llama-3-8B 从 32MB 降到 1MB。
| 配置 | num_heads | num_kv_heads | 单层 KV Cache (MB) | 相对 MHA |
|---|---|---|---|---|
| MHA | 32 | 32 | 32 | 100% |
| MQA | 32 | 1 | 1 | 3% |
精度损失
MQA 的代价是表达能力下降。Shazeer 论文[2]指出,MQA 在某些任务上质量略有下降,但推理速度提升显著。
本质上,MQA 限制了模型学习多样化的注意力模式,因为所有头被迫使用相同的 K/V。这在参数量大、头数多的模型上影响更明显。
Grouped-Query Attention (GQA)
GQA 是 Ainslie 等人在 2023 年提出的折中方案[3]。核心思想:num_kv_heads 介于 1 和 num_heads 之间,多个 query head 共享一组 KV。
分组策略
设 num_heads = 32,num_kv_heads = 8,则每 4 个 query head 共享同一组 KV:
每个 group 内的 query 共享 K/V,但不同 group 之间独立。
显存对比
| 配置 | num_heads | num_kv_heads | 单层 KV Cache (MB) | 相对 MHA |
|---|---|---|---|---|
| MHA | 32 | 32 | 32 | 100% |
| MQA | 32 | 1 | 1 | 3% |
| GQA-8 | 32 | 8 | 8 | 25% |
| GQA-4 | 32 | 4 | 4 | 12.5% |
GQA-8(Llama-3-8B 配置)的 KV Cache 是 MHA 的 1/4,同时保留了 8 组独立的注意力模式,精度损失远小于 MQA。
Llama 配置
Hugging Face Transformers 中配置 GQA 只需设置 num_key_value_heads:
from transformers import LlamaConfig
# MHA(早期 Llama)
config_mha = LlamaConfig(
num_attention_heads=32,
num_key_value_heads=32,
)
# GQA(Llama-3-8B)
config_gqa = LlamaConfig(
num_attention_heads=32,
num_key_value_heads=8, # 每 4 个 query head 共享一套 KV
)attention_config.py
Ainslie 论文[3]还提出了一种”uptraining”方法,可以把已有的 MHA checkpoint 转换为 GQA,无需从头训练。
Multi-Head Latent Attention (MLA)
MLA 是 DeepSeek-V2 提出的创新方案[4]。核心思想:用低秩投影把 KV 压缩到 latent space,推理时再上投影回原维度。
低秩投影原理
MLA 不直接存储 K 和 V,而是存储压缩后的 latent 向量。设原始 KV 为 ,压缩后:
其中 是下投影矩阵, 是压缩后的维度。
推理时需要上投影:
是上投影矩阵,每步动态计算。
显存节省
DeepSeek-V2 报告[4]实现了 93.3% 的 KV Cache 缩减,相当于 64X 压缩。
压缩后的 KV Cache 大小:
假设 (压缩维度),相比 MHA 的 :
即使考虑上投影矩阵的额外开销,整体显存占用仍显著低于 GQA。
DeepSeek-V2 配置
MLA 的实现更复杂,涉及多个投影矩阵。以 DeepSeek-V2-Lite 为例:
# MLA 伪代码示意
k_compressed = k_proj(x) @ c_k_down # 下投影
k_latent = k_compressed @ u_k # 上投影到 seq_len 维度
关键在于 和 的设计,前者是可学习的静态矩阵,后者是位置相关的动态上投影。
综合对比与选型指南
四种机制的完整对比:
| 机制 | num_kv_heads | 单层 KV (MB) | 相对 MHA | 精度损失 | 推理速度 | 代表模型 |
|---|---|---|---|---|---|---|
| MHA | = num_heads | 32 | 100% | 无 | 基线 | 原始 Transformer, BERT |
| MQA | = 1 | 1 | 3% | 明显 | 快 | PaLM, T5 |
| GQA-8 | num_heads/4 | 8 | 25% | 轻微 | 接近 MQA | Llama-2/3, Gemma |
| GQA-4 | num_heads/8 | 4 | 12.5% | 中等 | 接近 MQA | 可配置 |
| MLA | 低秩投影 | ~2-4 | 6-12% | 中等 | 中等 | DeepSeek-V2 |
不同场景的推荐配置
| 场景 | 序列长度 | 推荐机制 | 理由 |
|---|---|---|---|
| 短上下文训练 (< 4K) | < 4096 | MHA | 精度优先,显存不是瓶颈 |
| 中等长度 (4K-8K) | 4096-8192 | GQA-8 | 平衡精度和显存,Llama-3 已验证 |
| 长上下文推理 (8K-32K) | 8192-32768 | GQA-4 或 MLA | 显存成为主要瓶颈 |
| 超长上下文 (32K+) | > 32768 | MLA | 需要极致压缩 |
Hugging Face 配置代码
完整的显存计算函数:
def calc_kv_cache_memory(seq_len, num_kv_heads, d_kv, num_layers=1, dtype_bytes=2):
"""计算 KV Cache 显存占用(字节)
Args:
seq_len: 序列长度
num_kv_heads: KV 头数(MHA=num_heads, MQA=1, GQA=分组数)
d_kv: 每头维度
num_layers: 层数(默认单层)
dtype_bytes: 每元素字节数(BF16/FP16=2, FP32=4)
Returns:
显存占用(MB)
"""
memory_bytes = num_layers * seq_len * num_kv_heads * d_kv * 2 * dtype_bytes
return memory_bytes / (1024 * 1024)
# 示例:Llama-3-8B 单层,4096 长度
memory_mha = calc_kv_cache_memory(4096, 32, 128) # 32 MB
memory_gqa = calc_kv_cache_memory(4096, 8, 128) # 8 MB
memory_mqa = calc_kv_cache_memory(4096, 1, 128) # 1 MBkv_cache_calc.py
实际推理时,KV Cache 需要乘以层数。Llama-3-8B 有 32 层:
# 完整模型的 KV Cache
num_layers = 32 # Llama-3-8B
full_mha = calc_kv_cache_memory(4096, 32, 128, num_layers) # 1024 MB
full_gqa = calc_kv_cache_memory(4096, 8, 128, num_layers) # 256 MB
full_mqa = calc_kv_cache_memory(4096, 1, 128, num_layers) # 32 MBfull_model_calc.py
GQA 把 1GB 的 KV Cache 压到 256MB,单张 24GB 显卡就能跑 32K 上下文;MQA 进一步压到 32MB,但精度损失较大。
小结
从 MHA 到 MLA 的演进,核心矛盾是精度 vs 显存。MHA 是精度基准,MQA 极致压缩但损失质量,GQA 是实用平衡点,MLA 则是针对长上下文的极致优化方案。选择哪种机制,取决于你的场景:训练用 MHA,推理用 GQA,超长上下文考虑 MLA。