Skip to content
传衡博客
返回

MHA vs MQA vs GQA vs MLA:四种 Attention 机制显存与性能全对比

参考资料
  1. Attention Is All You Need - Vaswani et al., 2017
  2. Fast Transformer Decoding: One Write-Head is All You Need - Shazeer et al., 2019
  3. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints - Ainslie et al., 2023
  4. DeepSeek-V2: Technical Report - DeepSeek-AI, 2024

GQA 通过分组共享 KV,在保持精度的同时将单层 KV Cache 从 32MB 压缩到 4MB;MLA 进一步用低秩投影压缩 KV Cache,DeepSeek-V2 实现了 93.3% 的显存缩减(64X)[4]。从 MHA 到 MLA,四种 Attention 机制代表了 LLM 推理优化的完整演进路径。

标准 Multi-Head Attention (MHA)

MHA 是原始 Transformer 的核心组件[1]。每个头独立计算 Q、K、V 三个投影矩阵,注意力输出是各头的拼接。

标准 Attention 计算公式:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中 QRn×hdkQ \in \mathbb{R}^{n \times h \cdot d_k}K,VRn×hdkK, V \in \mathbb{R}^{n \times h \cdot d_k}hh 是头数,dkd_k 是每头维度。

KV Cache 显存分析

推理时需要缓存每层的 K 和 V 用于自回归生成。单层 KV Cache 显存(BF16,每元素 2 字节):

MemoryKV=seq_len×nkv_heads×dkv×2×2 bytes\text{Memory}_{\text{KV}} = \text{seq\_len} \times \text{nkv\_heads} \times d_{kv} \times 2 \times 2 \text{ bytes}

两个 2 分别表示 K/V 两个矩阵和 BF16 的字节数。

以 Llama-3-8B 单层为例:seq_len=4096, num_heads=32, d_kv=128

4096×32×128×2×2=32 MB4096 \times 32 \times 128 \times 2 \times 2 = 32 \text{ MB}

小矩阵手动计算

用 4×2 矩阵演示 MHA 的 Attention 计算。设 dk=2d_k=2h=2h=2,seq_len=4:

Q=[12345678],K=[23456789],V=[1011121314151617]Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \\ 7 & 8 \end{bmatrix}, \quad K = \begin{bmatrix} 2 & 3 \\ 4 & 5 \\ 6 & 7 \\ 8 & 9 \end{bmatrix}, \quad V = \begin{bmatrix} 10 & 11 \\ 12 & 13 \\ 14 & 15 \\ 16 & 17 \end{bmatrix}

计算 S=QKT/dk=QKT/2S = QK^T / \sqrt{d_k} = QK^T / \sqrt{2}

S=12[8142026182634422838465838505870][5.669.9014.1418.3812.7318.3824.0429.7019.8026.8732.5341.0226.8735.3641.0249.50]S = \frac{1}{\sqrt{2}} \begin{bmatrix} 8 & 14 & 20 & 26 \\ 18 & 26 & 34 & 42 \\ 28 & 38 & 46 & 58 \\ 38 & 50 & 58 & 70 \end{bmatrix} \approx \begin{bmatrix} 5.66 & 9.90 & 14.14 & 18.38 \\ 12.73 & 18.38 & 24.04 & 29.70 \\ 19.80 & 26.87 & 32.53 & 41.02 \\ 26.87 & 35.36 & 41.02 & 49.50 \end{bmatrix}

计算 softmax(S)\text{softmax}(S)(每行独立):

O=softmax(S)×VO = \text{softmax}(S) \times V

这就是 MHA 的核心计算流程。每个头的 K、V 都需要完整存储,所以显存占用与头数线性相关。

Multi-Query Attention (MQA)

MQA 由 Shazeer 在 2019 年提出[2]。核心思想很简单:所有 query head 共享同一套 KV。

MHA 中,num_kv_heads = num_heads;MQA 中,num_kv_heads = 1。

显存节省

单层 KV Cache 变成:

MemoryKVMQA=seq_len×1×dkv×2×2\text{Memory}_{\text{KV}}^{\text{MQA}} = \text{seq\_len} \times 1 \times d_{kv} \times 2 \times 2

同样的配置下,MQA 的 KV Cache 是 MHA 的 1/h1/h 倍。Llama-3-8B 从 32MB 降到 1MB。

配置num_headsnum_kv_heads单层 KV Cache (MB)相对 MHA
MHA323232100%
MQA32113%

精度损失

MQA 的代价是表达能力下降。Shazeer 论文[2]指出,MQA 在某些任务上质量略有下降,但推理速度提升显著。

本质上,MQA 限制了模型学习多样化的注意力模式,因为所有头被迫使用相同的 K/V。这在参数量大、头数多的模型上影响更明显。

Grouped-Query Attention (GQA)

GQA 是 Ainslie 等人在 2023 年提出的折中方案[3]。核心思想:num_kv_heads 介于 1 和 num_heads 之间,多个 query head 共享一组 KV。

分组策略

设 num_heads = 32,num_kv_heads = 8,则每 4 个 query head 共享同一组 KV:

group_size=num_headsnum_kv_heads=328=4\text{group\_size} = \frac{\text{num\_heads}}{\text{num\_kv\_heads}} = \frac{32}{8} = 4

每个 group 内的 query 共享 K/V,但不同 group 之间独立。

显存对比

配置num_headsnum_kv_heads单层 KV Cache (MB)相对 MHA
MHA323232100%
MQA32113%
GQA-8328825%
GQA-4324412.5%

GQA-8(Llama-3-8B 配置)的 KV Cache 是 MHA 的 1/4,同时保留了 8 组独立的注意力模式,精度损失远小于 MQA。

Llama 配置

Hugging Face Transformers 中配置 GQA 只需设置 num_key_value_heads

from transformers import LlamaConfig

# MHA(早期 Llama)
config_mha = LlamaConfig(
    num_attention_heads=32,
    num_key_value_heads=32,
)

# GQA(Llama-3-8B)
config_gqa = LlamaConfig(
    num_attention_heads=32,
    num_key_value_heads=8,   # 每 4 个 query head 共享一套 KV
)attention_config.py

Ainslie 论文[3]还提出了一种”uptraining”方法,可以把已有的 MHA checkpoint 转换为 GQA,无需从头训练。

Multi-Head Latent Attention (MLA)

MLA 是 DeepSeek-V2 提出的创新方案[4]。核心思想:用低秩投影把 KV 压缩到 latent space,推理时再上投影回原维度。

低秩投影原理

MLA 不直接存储 K 和 V,而是存储压缩后的 latent 向量。设原始 KV 为 KfullRn×hdkK_{\text{full}} \in \mathbb{R}^{n \times h \cdot d_k},压缩后:

Kcompressed=KfullcK,cKRhdk×dcK_{\text{compressed}} = K_{\text{full}} \cdot c_K, \quad c_K \in \mathbb{R}^{h \cdot d_k \times d_c}

其中 cKc_K 是下投影矩阵,dchdkd_c \ll h \cdot d_k 是压缩后的维度。

推理时需要上投影:

Kupprojected=KcompresseduK,uKRdc×nK_{\text{upprojected}} = K_{\text{compressed}} \cdot u_K, \quad u_K \in \mathbb{R}^{d_c \times n}

uKu_K 是上投影矩阵,每步动态计算。

显存节省

DeepSeek-V2 报告[4]实现了 93.3% 的 KV Cache 缩减,相当于 64X 压缩。

压缩后的 KV Cache 大小:

MemoryKVMLA=seq_len×dc×2×2\text{Memory}_{\text{KV}}^{\text{MLA}} = \text{seq\_len} \times d_c \times 2 \times 2

假设 dc=32d_c = 32(压缩维度),相比 MHA 的 hdk=32×128=4096h \cdot d_k = 32 \times 128 = 4096

324096=0.78%\frac{32}{4096} = 0.78\%

即使考虑上投影矩阵的额外开销,整体显存占用仍显著低于 GQA。

DeepSeek-V2 配置

MLA 的实现更复杂,涉及多个投影矩阵。以 DeepSeek-V2-Lite 为例:

# MLA 伪代码示意
k_compressed = k_proj(x) @ c_k_down      # 下投影
k_latent = k_compressed @ u_k            # 上投影到 seq_len 维度

关键在于 ckc_kuku_k 的设计,前者是可学习的静态矩阵,后者是位置相关的动态上投影。

综合对比与选型指南

四种机制的完整对比:

机制num_kv_heads单层 KV (MB)相对 MHA精度损失推理速度代表模型
MHA= num_heads32100%基线原始 Transformer, BERT
MQA= 113%明显PaLM, T5
GQA-8num_heads/4825%轻微接近 MQALlama-2/3, Gemma
GQA-4num_heads/8412.5%中等接近 MQA可配置
MLA低秩投影~2-46-12%中等中等DeepSeek-V2

不同场景的推荐配置

场景序列长度推荐机制理由
短上下文训练 (< 4K)< 4096MHA精度优先,显存不是瓶颈
中等长度 (4K-8K)4096-8192GQA-8平衡精度和显存,Llama-3 已验证
长上下文推理 (8K-32K)8192-32768GQA-4 或 MLA显存成为主要瓶颈
超长上下文 (32K+)> 32768MLA需要极致压缩

Hugging Face 配置代码

完整的显存计算函数:

def calc_kv_cache_memory(seq_len, num_kv_heads, d_kv, num_layers=1, dtype_bytes=2):
    """计算 KV Cache 显存占用(字节)

    Args:
        seq_len: 序列长度
        num_kv_heads: KV 头数(MHA=num_heads, MQA=1, GQA=分组数)
        d_kv: 每头维度
        num_layers: 层数(默认单层)
        dtype_bytes: 每元素字节数(BF16/FP16=2, FP32=4)

    Returns:
        显存占用(MB)
    """
    memory_bytes = num_layers * seq_len * num_kv_heads * d_kv * 2 * dtype_bytes
    return memory_bytes / (1024 * 1024)

# 示例:Llama-3-8B 单层,4096 长度
memory_mha = calc_kv_cache_memory(4096, 32, 128)    # 32 MB
memory_gqa = calc_kv_cache_memory(4096, 8, 128)     # 8 MB
memory_mqa = calc_kv_cache_memory(4096, 1, 128)     # 1 MBkv_cache_calc.py

实际推理时,KV Cache 需要乘以层数。Llama-3-8B 有 32 层:

32 MB/层×32 层=1024 MB(MHA, GQA-8)32 \text{ MB/层} \times 32 \text{ 层} = 1024 \text{ MB} \quad (\text{MHA, GQA-8})
# 完整模型的 KV Cache
num_layers = 32  # Llama-3-8B
full_mha = calc_kv_cache_memory(4096, 32, 128, num_layers)    # 1024 MB
full_gqa = calc_kv_cache_memory(4096, 8, 128, num_layers)     # 256 MB
full_mqa = calc_kv_cache_memory(4096, 1, 128, num_layers)     # 32 MBfull_model_calc.py

GQA 把 1GB 的 KV Cache 压到 256MB,单张 24GB 显卡就能跑 32K 上下文;MQA 进一步压到 32MB,但精度损失较大。

小结

从 MHA 到 MLA 的演进,核心矛盾是精度 vs 显存。MHA 是精度基准,MQA 极致压缩但损失质量,GQA 是实用平衡点,MLA 则是针对长上下文的极致优化方案。选择哪种机制,取决于你的场景:训练用 MHA,推理用 GQA,超长上下文考虑 MLA。



Previous Post
Activation Checkpointing:用时间换显存的艺术
Next Post
LoRA 与 QLoRA:从低秩适配到双重量化