Skip to content
传衡博客
返回

【四】SFT 微调

参考资料
  1. QLoRA: Efficient Finetuning of Quantized LLMs
  2. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  3. Unsloth — Fine-tune Llama 3.1 (70B/8B) and Mistral NeMo (12B) 2-5x faster with 70% less memory
  4. Gradient Checkpointing — Training Deep Nets with Sublinear Memory Cost

单张 RTX 3090(24 GiB)上,用 Qwen3-8B 做 32K 长上下文 SFT,峰值 17.18 GiB,GPU 利用率 98.9%。四层优化各自解决一个问题:QLoRA 压模型权重,FA2 消灭 attention 的 O(n²) 显存,gradient checkpointing 压激活,gradient offload 把梯度挪到 CPU。这篇文章拆开讲每一层做了什么,以及一个差点让训练全挂的 CUBLAS bug。

这是 TraceForge 系列的第四篇。第五篇讲偏好训练特有的两个问题:completion 边界定义和双卡拆 pair。

不优化的话需要多少显存

先把 Qwen3-8B 在 32K 下的显存需求拆开算。

显存组成计算方式大小
模型权重 (bf16)8.03B × 2B16.4 GiB
优化器状态 (Adam)权重 × 2 (m, v) × 2B + 权重副本98.3 GiB
梯度 (bf16)与权重同大16.4 GiB
Attention 中间结果batch × heads × seq² × 2B28.8 TiB (32K)
层间激活36 层 × hidden × seq × 2B × 系数~30 GiB
Logitsseq × vocab × 2B9.5 GiB
合计~170 GiB+

光模型权重加优化器状态就要 130 GiB,一张 3090 连模型都装不下。Attention 更是天文数字。必须每一层都做优化。

第一层:QLoRA 压模型和优化器

问题:8B 参数的模型需要 16.4 GiB 权重 + 98.3 GiB 优化器状态。

QLoRA [1] 做两件事。第一,把基座模型量化到 4-bit NF4(带 double quantization),权重从 16.4 GiB 压到 4.6 GiB。第二,冻结基座参数,只在每层注入低秩 LoRA adapter(rank=16)做微调。可训练参数只有 44M(占总参数的 0.53%),优化器状态从 98.3 GiB 降到 0.3 GiB

model, tokenizer = FastLanguageModel.from_pretrained(
    "qwen3-8b-unsloth-bnb-4bit",     # 4-bit 预量化权重
    max_seq_length=32768,
    load_in_4bit=True,                # NF4 量化
)
model = FastLanguageModel.get_peft_model(
    model, r=16, lora_alpha=16,       # LoRA rank=16, 44M 可训练参数
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

效果: 权重 + 优化器从 ~131 GiB 压到 ~5 GiB。模型装进了 3090。

第二层:Flash Attention 2 消灭 O(n²)

问题:标准 attention 需要存储 n×n 的 score 矩阵。32K 序列 = 28.8 TiB。

标准 self-attention 计算 softmax(QK^T / √d) V,中间的 QK^T 是一个 seq × seq 的矩阵。32K 序列下,每个 head 的 score 矩阵就是 32768 × 32768 × 2B = 2 GiB,乘以 batch × num_heads,总量级约 28.8 TiB。

Flash Attention 2 [2] 用 tiling + online softmax 的方式,把整个 attention 计算压缩到 SRAM 里做完,不需要把完整 score 矩阵写回 HBM。显存从 O(n²) 降到 O(n),28.8 TiB 直接消失。

效果: Attention 显存从 TiB 级降到可忽略。

第三层:Gradient Checkpointing 压激活

问题:前向传播时每一层都要保存中间激活,供反向传播使用。36 层 × 32K 序列 ≈ 30 GiB。

Gradient checkpointing [4] 的思路是:前向时只保存少数检查点(比如每隔几层保存一次),反向时从最近的检查点重新计算丢弃的中间激活。用额外一次前向计算换取显存释放。

Unsloth 的实现比标准 PyTorch 的 torch.utils.checkpoint 更激进 [3],在 Qwen3 的每一层内部都做了细粒度 checkpointing。实测 32K 序列下,激活显存从 ~30 GiB 压到不到 15 GiB

效果: 激活显存减半,代价是训练速度降低约 30%。

第四层:Gradient Offload 把梯度挪到 CPU

问题:即使只训练 LoRA 参数(44M),梯度张量仍占 GPU 显存。

Unsloth 的 gradient offload 把 LoRA 参数的梯度在反向传播后立即移到 CPU 内存,GPU 上只保留当前正在更新的那一层的梯度。对于 rank=16 的 LoRA,梯度总量约 0.09 GiB,offload 后 GPU 上的梯度占用接近 0

效果: 省的绝对值不大(~0.1 GiB),但在显存接近上限时,每一点都有意义。

四层叠加的效果

SFT 显存优化前后对比
图 1:32K SFT 各显存组成部分的优化前后对比(dumbbell chart)。Attention 的 28.8 TiB → ≈0 变化太大,未在图中显示。
显存组成优化前优化后用了什么
模型权重16.4 GiB4.6 GiB4-bit QLoRA
优化器状态98.3 GiB0.3 GiBQLoRA(只更新 LoRA 参数)
梯度16.4 GiB≈0(CPU)Gradient Offload
Attention28.8 TiB≈0Flash Attention 2
层间激活~30 GiB~15 GiBGradient Checkpointing
Logits (32K)9.5 GiB9.5 GiB未优化(SFT 全序列算 loss)

SFT 场景下,整个序列都是 completion,logits 没法省(不像偏好训练可以只对 action span 算 logits)。32K × 151936 vocab × 2B = 9.5 GiB 是一个硬性开销。

实测峰值 17.18 GiB(32K 真实 debug 轨迹),占 3090 总显存的 72%,余量 6.4 GiB。

真实轨迹训练结果

用 SWE-bench 的 3 个 Django 实例(django-11490, django-11749, django-12965)的真实 debug_subagent 轨迹做 SFT,共 25 个训练样本。

32K SFT 训练结果
图 2:32K SFT 真实轨迹训练结果摘要。峰值显存 17.18 GiB,GPU 利用率 98.9%,24 分钟完成 20 步训练。
指标
模型Qwen3-8B 4-bit QLoRA (unsloth)
序列长度32K(实际 median 18,489 tokens)
训练样本25
训练步数20(约 3 个 epoch)
训练时间1,453 秒(~24 分钟)
最终 loss1.0597
峰值显存17.18 GiB (allocated) / 18,038 MiB (monitor)
GPU 利用率平均 98.9%,峰值 100%
功耗平均 342W,峰值 350W
温度峰值 67°C

GPU 利用率 98.9% 说明显存和计算都被充分利用。 没有因为 batch_size=1 导致 GPU 空转,gradient accumulation(步长 4)保证了有效 batch size。

一个差点全挂的 CUBLAS Bug

训练脚本写完,第一次跑就崩了:CUBLAS_STATUS_INVALID_VALUE

现象: torch.mm()torch.matmul() 在 bf16/fp16 下全部报错,但 torch.addmm() 正常。排查了三天,一度以为是显卡硬件问题。

根因: torch 2.10.0 + CUDA 12.8 的 cublasGemmEx API 在 Ampere GPU(SM 8.6,也就是 RTX 3090)上有一个半精度矩阵乘法 bug。addmm 走了另一条代码路径(cublasGemmStridedBatchedEx),所以没事。

修复一行:

torch.backends.cuda.preferred_blas_library("cublaslt")

强制使用 CUBLAS-LT API 代替 cublasGemmEx,问题消失。这个 bug 不会报任何有意义的错误信息,只有一个泛化的 CUBLAS_STATUS_INVALID_VALUE,非常难定位。

小结

四层优化(QLoRA + FA2 + gradient checkpointing + gradient offload)把 Qwen3-8B 的 32K SFT 从理论 170+ GiB 压到单卡 3090 的 17.18 GiB。每一层解决一个具体的显存组成部分,缺一不可。SFT 场景下 logits 是全序列计算,无法进一步压缩。在偏好训练场景下,可以通过定义 completion 边界(只对 action span 算 logits)进一步缩减这部分开销,详见第五篇:偏好训练的显存优化



Previous Post
【五】SimPO 训练与 BranchParallel 策略实现
Next Post
【三】动态调用图的实现与压缩