参考资料
- SimPO: Simple Preference Optimization with a Reference-Free Reward
- Unsloth — Fine-tune Llama 3.1 (70B/8B) and Mistral NeMo (12B) 2-5x faster with 70% less memory
- TRL — CPOTrainer / ORPOTrainer 源码
- logits_to_keep — transformers 4.45+ (PR #33005)
- ORPO: Monolithic Preference Optimization without Reference Model
- Unsloth GitHub Issue #1990 — support model parallel training?
两块 RTX 3090(各 24 GiB)上,把 Agent 轨迹的 SimPO 偏好训练从 ORPO 的 24K OOM 推到 80K 跑通,峰值 21.05 GiB/卡。最关键的两步不是什么高级并行技术,而是:搞清楚 80K 序列里只有 4K token 需要算 logits,然后把 chosen/rejected 两条 branch 拆到两张卡上。
这是 TraceForge 系列的第五篇。第四篇讲了 QLoRA、FA2、gradient checkpointing 等基础优化如何让 32K SFT 在单卡 3090 上跑通。本篇在此基础上,解决偏好训练特有的两个问题。
偏好训练为什么比 SFT 贵一倍
SFT 只有一条序列。 给定 prompt + completion,模型做一次前向、一次反向,显存占用就是这一条序列的代价。Qwen3-8B 的 32K SFT 在单卡 3090 上实测 21.3 GiB,用了 4-bit QLoRA + Flash Attention 2 + gradient checkpointing [2],余量舒适。
偏好训练要处理一对序列。 ORPO、SimPO、DPO 这类 pairwise 方法,每个训练样本都包含 chosen 和 rejected 两条 completion。训练器需要同时对两条序列做前向,才能计算偏好 loss。MLP 激活、logits 张量、梯度检查点输入,全部翻倍。
实测边界很清楚。 在同一台机器上,ORPO 单卡能跑 16K,到 24K 就 OOM(差 1.09 GiB)。
但真实数据需要长上下文。 我的场景是训练一个 PDB 调试 Agent,它的轨迹天然很长。9 条真实 debug 轨迹中,5 条超过 32K token,2 条超过 64K,最长的达到 83,720 token。长上下文训练不是在刷数字,而是被任务形态决定的。
走不通的四条路
在找到可行方案之前,我试了四条主流思路,全部失败。
| 方案 | 原理 | 失败原因 |
|---|---|---|
| Ulysses SP=2/4 | 沿序列维度切分 attention | ORPO batch key 与 Ulysses adapter 不对齐,张量没被切分 |
| Arctic ALST | CPU offload + tiled MLP | 16K 下两个变体均 OOM(21.07 vs 21.09 GiB),收益倒挂 |
| Ring CP=2 | Context Parallelism | FSDP2 + PEFT 的 prepare() 阶段崩溃,集成太厚 |
| logits_to_keep(错误版) | 只保留尾部 logits | 连 input_ids 也被截短,模型只看了 3K token |
前三条是”方向对但接不上”。 Ulysses 的 SP 声称启用但张量没变,说明 ORPO 的 batch 构造和 SP adapter 期望的 key 对不上。ALST 在 sdpa 路径上反而更贵。Ring CP 在 FSDP2 + PEFT 的交互层就卡死了。
第四条最隐蔽。 这个 logits_to_keep patch 显存确实降了(从 OOM 降到 52% 占用),但后来发现它不只截了 lm_head 输出,还把 input_ids 和 attention_mask 也裁短了。模型实际只看了尾部 3K token 做 attention,32K 上下文根本没进计算图。这不是显存优化,是悄悄换了一个更短的任务。这个教训后来反过来帮我理解了 logits 的正确用法。
80K 序列里只有 4K token 需要算 logits
这是整件事最关键的认知转变。
一条 Agent 轨迹的结构: 主 Agent 把 question(问题描述)和 test(失败的测试)传给 debug subagent,subagent 多轮交互产生一系列 action(工具调用)和 observation(环境返回)。一条典型轨迹大约 20 轮,最终 prompt 长度从 13K 到 84K 不等。
偏好训练的 score 只应该算在”模型的决策”上。 SimPO 的 avg_logp 表示 completion 部分的平均 log 概率。但”completion”在 Agent 轨迹里到底是什么?question 和 test 是任务输入,observation 是环境的确定性输出,这些都不是模型的决策。只有 action span 才代表模型真正需要学习的行为。
我的三次尝试:
第一版把 question + test 当 prompt,其余全部当 completion。一条 80K 的轨迹,question + test 可能只占 5K,那 “completion” 就有 75K token,每个都需要 lm_head 投影到 151936 维 vocab。单个 branch 的 logits 张量就要 75K × 151936 × 2B ≈ 22 GiB,直接 OOM。
第二版把 observation 也从 completion 里去掉,只保留 agent 的 action span。从实际轨迹统计看,9 条轨迹的 action token 总量只有 16,525,平均每次 API 调用约 86 token。一条轨迹的全部 action span 加起来大约 2K~5K token。logits 张量从 22 GiB 骤降到 4K × 151936 × 2B ≈ 1.2 GiB,直接可行。
第二版成功了。
这不是一个”显存优化技巧”。 logits_to_keep 本身是 transformers 4.45+ 的标准功能 [4],标准 SFT 和偏好训练里早就在用。真正的问题不是”用不用 logits_to_keep”,而是”在 Agent 轨迹里,completion 的边界在哪里”。
# 错误:75K completion → 22 GiB logits → OOM
labels = mark_everything_except_question_as_completion(input_ids)
# 正确:4K action span → 1.2 GiB logits → OK
labels = mark_only_action_spans_as_completion(input_ids)
# 其余位置 labels = -100,前向仍看全文,loss 只落在 action token 上
模型仍然对完整 80K token 做 attention。 logits_to_keep 不截断 input_ids,不改变 attention 计算。模型看到完整的 question、test、所有 observation、所有历史 action,但只在最后的 lm_head 阶段对 action span 的位置做 vocab 投影。上下文信息完整保留,logits 开销降到最低。
从 pair 结构找突破口
定义好 completion 边界后,logits 不再是瓶颈。但偏好训练还有第二个问题:一对序列的 MLP 激活和梯度检查点输入仍然翻倍。
回头看数据就会发现:单条 32K branch 本身是能训的(SFT 21.3 GiB),偏好训练爆显存是因为两条 branch 同时在一张卡上。瓶颈在 pair 结构,不在单条 branch。
为什么不用 Context Parallelism
Context Parallelism(CP)沿序列维度切分,看起来是长上下文的自然选择。但在 2 卡偏好训练这个场景下,它的收益和 branch-parallel 一样,工程代价却高得多。
先算显存。 CP=2 把每条 80K 序列切成两个 40K,但每张卡仍然要处理 chosen + rejected 两条序列。每卡 MLP 激活:2 branches × 40K = 80K token。Branch-parallel 每卡只处理一条完整 80K 序列。每卡 MLP 激活:1 branch × 80K = 80K token。两种方案在 2 卡上的 per-GPU 激活负载完全相同。
再看通信。 CP 需要每一层 attention 都做 KV 交换(ring attention),Qwen3-8B 有 36 层,每步要做 36 × 2 branches = 72 次跨卡通信。Branch-parallel 全程只交换 2 个标量 score + 一次 all-reduce LoRA 梯度。
最后看集成。 Unsloth 截至 2026 年初仍不支持 CP/SP,只提供 DDP [6]。CP 需要 FSDP2 或 DeepSpeed 与 PEFT 的深度集成,而我实测 Ring CP=2 在 FSDP2 + PEFT 的 prepare() 阶段直接崩溃。即使 CP 的集成问题被修复,在 2 卡偏好训练的场景下也不会比 branch-parallel 更省显存。
| 维度 | Context Parallelism (CP=2) | Branch-Parallel |
|---|---|---|
| 每卡 MLP 激活 | 2 branches × 40K = 80K token | 1 branch × 80K = 80K token |
| 每步跨卡通信 | 72 次 KV 交换(36 层 × 2 branches) | 2 个标量 + 1 次 LoRA grad all-reduce |
| 与 unsloth/QLoRA 集成 | 不支持,需 FSDP2 + PEFT | 直接可用,手动 gloo 同步 |
| 实现复杂度 | ring attention + 分布式框架 plumbing | 约 50 行 PyTorch 代码 |
CP 在更多卡(4+)或更长序列(160K+)的场景下有价值,因为它可以和 branch-parallel 正交组合:4 卡 = 2 CP × 2 branch。但在 2 卡 80K 这个约束下,branch-parallel 是更简单、通信更轻、集成更薄的选择。
选择 SimPO 的原因
SimPO 的 loss 函数天然适合拆分。 它的目标是纯对称的 pairwise 偏好 [1]:
其中 和 分别是 chosen 和 rejected 的 completion-only 平均 log 概率。对参数 θ 求导,链式法则给出:
两项分别只依赖各自 branch 的前向计算,可以在不同 GPU 上独立完成。DPO 需要额外的 reference model(4 次前向),ORPO 的 chosen 侧有额外 NLL 项(不对称),SimPO 是三者中拆分代价最低的。
Branch-Parallel 协议
核心只有 5 步:
# rank 0: chosen branch, rank 1: rejected branch
score = forward(branch_data) # 各自前向,互不干扰
scores = all_gather(score) # 交换两个标量
loss = simpo_loss(scores, rank) # 各自用 detach 对方的 score 算 loss
loss.backward() # 各自反向,梯度只流过自己的 branch
all_reduce_lora_grads(model, SUM) # 合并梯度 → optimizer.step()
不用 DDP。 unsloth 每卡独立加载模型,梯度同步用手动 dist.all_reduce(grad, op=SUM),没有框架自动 average 的缩放问题。两卡从同一初始参数出发,收到相同的合并梯度,每步更新后参数保持一致。
用 gloo 绕过 NCCL。 每个 worker 通过 CUDA_VISIBLE_DEVICES 只看到一块 GPU(cuda:0),但 NCCL 用 rank 作 device ordinal,rank 1 会尝试访问不存在的 cuda:1。gloo 走 CPU 通信,对于只交换标量 score + LoRA 梯度的场景,延迟可忽略。
等价性已验证。 短序列固定种子下,branch-parallel 与标准单卡 CPOTrainer 的 loss、梯度余弦相似度 0.9998,最大参数差异 1.5e-6。
爬坡结果:32K 到 80K
从 32K 开始逐步推高序列长度。基础配置统一为:Qwen3-8B + 4-bit QLoRA + unsloth + FA2 + gradient checkpointing + 双卡 branch-parallel + logits_to_keep(只对 action span 算 logits)。
| 序列长度 | 实际 tokens | 步数 | 峰值 GiB/卡 | 余量 GiB | 秒/步 | 状态 |
|---|---|---|---|---|---|---|
| 32K | 32,768 | 1 | 13.61 | 9.95 | 45 | Pass |
| 64K | 65,536 | 10 | 19.01 | 4.55 | 130 | Pass |
| 72K | 73,728 | 5 | 20.31 | 3.25 | 159 | Pass |
| 80K | 81,851 | 1 | 21.05 | 2.51 | 194 | Pass |
| 90K | — | — | OOM | — | — | Fail |
显存增长亚线性。 32K 到 64K 序列翻倍,显存只增长 1.40x(13.61 → 19.01 GiB),说明 unsloth 的 gradient checkpointing 在长序列下效率很高。
80K 是当前实用上限。 81,851 token(77K prompt + 4K completion),峰值 21.05 GiB(占总显存 89.4%),吞吐 422 tokens/s/GPU。
72K 运行的 GPU 监控数据(2 秒采样,195 个样本点,5 步全程无异常):
| 指标 | GPU 0 (chosen) | GPU 1 (rejected) |
|---|---|---|
| 峰值显存 | 22,870 MiB | 22,870 MiB |
| 平均显存 | 22,870 MiB | 22,870 MiB |
| 平均 GPU 利用率 | 99.0% | 98.5% |
| 峰值 GPU 利用率 | 100% | 100% |
| 平均功耗 | 339.5 W | 342.4 W |
| 峰值温度 | 67°C | 67°C |
两卡显存全程钉在峰值,利用率接近 100%,不是 bursty 的假成功。
90K 的 OOM 是碎片问题。 可用碎片 2.59 GiB 大于需要申请的连续块 2.11 GiB,但 PyTorch 分配器找不到足够大的连续空间。first-fail 始终在 MLP SwiGLU(swiglu_fg_kernel)。
各层优化的贡献
80K 偏好训练的显存由六个部分组成。下表列出每个部分在有/无对应优化时的大小。
| 显存组成 | 无优化时 | 优化后 | 用了什么 |
|---|---|---|---|
| 模型权重 | 16.4 GiB (bf16) | 4.6 GiB | 4-bit QLoRA |
| 优化器状态 | 98.3 GiB (Adam bf16) | 0.3 GiB | QLoRA(只更新 LoRA 参数) |
| Attention 中间结果 | 28.8 TiB (O(n²) score 矩阵) | ≈0 | Flash Attention 2 |
| MLP / 层间激活 | 266 GiB | 30 GiB | Gradient Checkpointing (unsloth) |
| 梯度张量 | 16.4 GiB (bf16 全参) | ≈0(CPU) | Gradient Offload (unsloth) |
| Logits 张量 | 46.4 GiB (全量 80K × vocab × 2 branches) | 2.3 GiB | logits_to_keep(只对 action span 投影) |
| 双 branch 叠加 | 以上全部 ×2 / 单卡 | ÷2 / 每卡 | Branch-Parallel 双卡 |
QLoRA、FA2、gradient checkpointing、gradient offload 是标准实践,不是本文的贡献。 任何做长上下文 LoRA 训练的人都会用这些。它们把 baseline 压到一个可工作的起点,但偏好训练还剩两个问题需要单独解决:
第一,logits 该算到哪里。 80K 序列的全量 logits 要 46.4 GiB(两个 branch),但只对 action span 算 logits 就只需要 2.3 GiB。差距来自对 completion 边界的定义:如果把 observation 也当 completion,显存直接爆炸;如果只算 action,轻松通过。
第二,pair 结构怎么拆。 偏好训练必须同时处理 chosen 和 rejected,单卡放不下就拆到两卡。SimPO 的对称 loss 函数让梯度可以精确拆分,两卡只需交换两个标量 score,然后 all-reduce LoRA 梯度。
小结
把 Agent 轨迹的偏好训练推到 80K(77K prompt + 4K action),核心靠两件事:搞清楚 80K 序列里只有 action span 需要算 logits(从 46 GiB 降到 2.3 GiB),再用 branch-parallel 把 pair 的两条 branch 拆到两张卡上。在 2×RTX 3090 上实测 21.05 GiB/卡,余量 2.51 GiB。下一步是用真实 PDB 轨迹做多步训练,验证 loss 收敛和 chosen/rejected margin 变化。