参考资料
单张 RTX 3090(24 GiB)上,用 Qwen3-8B 做 32K 长上下文 SFT,峰值 17.18 GiB,GPU 利用率 98.9%。四层优化各自解决一个问题:QLoRA 压模型权重,FA2 消灭 attention 的 O(n²) 显存,gradient checkpointing 压激活,gradient offload 把梯度挪到 CPU。这篇文章拆开讲每一层做了什么,以及一个差点让训练全挂的 CUBLAS bug。
这是 TraceForge 系列的第四篇。第五篇讲偏好训练特有的两个问题:completion 边界定义和双卡拆 pair。
不优化的话需要多少显存
先把 Qwen3-8B 在 32K 下的显存需求拆开算。
| 显存组成 | 计算方式 | 大小 |
|---|---|---|
| 模型权重 (bf16) | 8.03B × 2B | 16.4 GiB |
| 优化器状态 (Adam) | 权重 × 2 (m, v) × 2B + 权重副本 | 98.3 GiB |
| 梯度 (bf16) | 与权重同大 | 16.4 GiB |
| Attention 中间结果 | batch × heads × seq² × 2B | 28.8 TiB (32K) |
| 层间激活 | 36 层 × hidden × seq × 2B × 系数 | ~30 GiB |
| Logits | seq × vocab × 2B | 9.5 GiB |
| 合计 | ~170 GiB+ |
光模型权重加优化器状态就要 130 GiB,一张 3090 连模型都装不下。Attention 更是天文数字。必须每一层都做优化。
第一层:QLoRA 压模型和优化器
问题:8B 参数的模型需要 16.4 GiB 权重 + 98.3 GiB 优化器状态。
QLoRA [1] 做两件事。第一,把基座模型量化到 4-bit NF4(带 double quantization),权重从 16.4 GiB 压到 4.6 GiB。第二,冻结基座参数,只在每层注入低秩 LoRA adapter(rank=16)做微调。可训练参数只有 44M(占总参数的 0.53%),优化器状态从 98.3 GiB 降到 0.3 GiB。
model, tokenizer = FastLanguageModel.from_pretrained(
"qwen3-8b-unsloth-bnb-4bit", # 4-bit 预量化权重
max_seq_length=32768,
load_in_4bit=True, # NF4 量化
)
model = FastLanguageModel.get_peft_model(
model, r=16, lora_alpha=16, # LoRA rank=16, 44M 可训练参数
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
)
效果: 权重 + 优化器从 ~131 GiB 压到 ~5 GiB。模型装进了 3090。
第二层:Flash Attention 2 消灭 O(n²)
问题:标准 attention 需要存储 n×n 的 score 矩阵。32K 序列 = 28.8 TiB。
标准 self-attention 计算 softmax(QK^T / √d) V,中间的 QK^T 是一个 seq × seq 的矩阵。32K 序列下,每个 head 的 score 矩阵就是 32768 × 32768 × 2B = 2 GiB,乘以 batch × num_heads,总量级约 28.8 TiB。
Flash Attention 2 [2] 用 tiling + online softmax 的方式,把整个 attention 计算压缩到 SRAM 里做完,不需要把完整 score 矩阵写回 HBM。显存从 O(n²) 降到 O(n),28.8 TiB 直接消失。
效果: Attention 显存从 TiB 级降到可忽略。
第三层:Gradient Checkpointing 压激活
问题:前向传播时每一层都要保存中间激活,供反向传播使用。36 层 × 32K 序列 ≈ 30 GiB。
Gradient checkpointing [4] 的思路是:前向时只保存少数检查点(比如每隔几层保存一次),反向时从最近的检查点重新计算丢弃的中间激活。用额外一次前向计算换取显存释放。
Unsloth 的实现比标准 PyTorch 的 torch.utils.checkpoint 更激进 [3],在 Qwen3 的每一层内部都做了细粒度 checkpointing。实测 32K 序列下,激活显存从 ~30 GiB 压到不到 15 GiB。
效果: 激活显存减半,代价是训练速度降低约 30%。
第四层:Gradient Offload 把梯度挪到 CPU
问题:即使只训练 LoRA 参数(44M),梯度张量仍占 GPU 显存。
Unsloth 的 gradient offload 把 LoRA 参数的梯度在反向传播后立即移到 CPU 内存,GPU 上只保留当前正在更新的那一层的梯度。对于 rank=16 的 LoRA,梯度总量约 0.09 GiB,offload 后 GPU 上的梯度占用接近 0。
效果: 省的绝对值不大(~0.1 GiB),但在显存接近上限时,每一点都有意义。
四层叠加的效果
| 显存组成 | 优化前 | 优化后 | 用了什么 |
|---|---|---|---|
| 模型权重 | 16.4 GiB | 4.6 GiB | 4-bit QLoRA |
| 优化器状态 | 98.3 GiB | 0.3 GiB | QLoRA(只更新 LoRA 参数) |
| 梯度 | 16.4 GiB | ≈0(CPU) | Gradient Offload |
| Attention | 28.8 TiB | ≈0 | Flash Attention 2 |
| 层间激活 | ~30 GiB | ~15 GiB | Gradient Checkpointing |
| Logits (32K) | 9.5 GiB | 9.5 GiB | 未优化(SFT 全序列算 loss) |
SFT 场景下,整个序列都是 completion,logits 没法省(不像偏好训练可以只对 action span 算 logits)。32K × 151936 vocab × 2B = 9.5 GiB 是一个硬性开销。
实测峰值 17.18 GiB(32K 真实 debug 轨迹),占 3090 总显存的 72%,余量 6.4 GiB。
真实轨迹训练结果
用 SWE-bench 的 3 个 Django 实例(django-11490, django-11749, django-12965)的真实 debug_subagent 轨迹做 SFT,共 25 个训练样本。
| 指标 | 值 |
|---|---|
| 模型 | Qwen3-8B 4-bit QLoRA (unsloth) |
| 序列长度 | 32K(实际 median 18,489 tokens) |
| 训练样本 | 25 |
| 训练步数 | 20(约 3 个 epoch) |
| 训练时间 | 1,453 秒(~24 分钟) |
| 最终 loss | 1.0597 |
| 峰值显存 | 17.18 GiB (allocated) / 18,038 MiB (monitor) |
| GPU 利用率 | 平均 98.9%,峰值 100% |
| 功耗 | 平均 342W,峰值 350W |
| 温度 | 峰值 67°C |
GPU 利用率 98.9% 说明显存和计算都被充分利用。 没有因为 batch_size=1 导致 GPU 空转,gradient accumulation(步长 4)保证了有效 batch size。
一个差点全挂的 CUBLAS Bug
训练脚本写完,第一次跑就崩了:CUBLAS_STATUS_INVALID_VALUE。
现象: torch.mm() 和 torch.matmul() 在 bf16/fp16 下全部报错,但 torch.addmm() 正常。排查了三天,一度以为是显卡硬件问题。
根因: torch 2.10.0 + CUDA 12.8 的 cublasGemmEx API 在 Ampere GPU(SM 8.6,也就是 RTX 3090)上有一个半精度矩阵乘法 bug。addmm 走了另一条代码路径(cublasGemmStridedBatchedEx),所以没事。
修复一行:
torch.backends.cuda.preferred_blas_library("cublaslt")
强制使用 CUBLAS-LT API 代替 cublasGemmEx,问题消失。这个 bug 不会报任何有意义的错误信息,只有一个泛化的 CUBLAS_STATUS_INVALID_VALUE,非常难定位。
小结
四层优化(QLoRA + FA2 + gradient checkpointing + gradient offload)把 Qwen3-8B 的 32K SFT 从理论 170+ GiB 压到单卡 3090 的 17.18 GiB。每一层解决一个具体的显存组成部分,缺一不可。SFT 场景下 logits 是全序列计算,无法进一步压缩。在偏好训练场景下,可以通过定义 completion 边界(只对 action span 算 logits)进一步缩减这部分开销,详见第五篇:偏好训练的显存优化。