Skip to content
传衡博客
返回

【五】SimPO 训练与 BranchParallel 策略实现

参考资料
  1. SimPO: Simple Preference Optimization with a Reference-Free Reward
  2. Unsloth — Fine-tune Llama 3.1 (70B/8B) and Mistral NeMo (12B) 2-5x faster with 70% less memory
  3. TRL — CPOTrainer / ORPOTrainer 源码
  4. logits_to_keep — transformers 4.45+ (PR #33005)
  5. ORPO: Monolithic Preference Optimization without Reference Model
  6. Unsloth GitHub Issue #1990 — support model parallel training?

两块 RTX 3090(各 24 GiB)上,把 Agent 轨迹的 SimPO 偏好训练从 ORPO 的 24K OOM 推到 80K 跑通,峰值 21.05 GiB/卡。最关键的两步不是什么高级并行技术,而是:搞清楚 80K 序列里只有 4K token 需要算 logits,然后把 chosen/rejected 两条 branch 拆到两张卡上。

这是 TraceForge 系列的第五篇。第四篇讲了 QLoRA、FA2、gradient checkpointing 等基础优化如何让 32K SFT 在单卡 3090 上跑通。本篇在此基础上,解决偏好训练特有的两个问题。

偏好训练为什么比 SFT 贵一倍

SFT 只有一条序列。 给定 prompt + completion,模型做一次前向、一次反向,显存占用就是这一条序列的代价。Qwen3-8B 的 32K SFT 在单卡 3090 上实测 21.3 GiB,用了 4-bit QLoRA + Flash Attention 2 + gradient checkpointing [2],余量舒适。

偏好训练要处理一对序列。 ORPO、SimPO、DPO 这类 pairwise 方法,每个训练样本都包含 chosen 和 rejected 两条 completion。训练器需要同时对两条序列做前向,才能计算偏好 loss。MLP 激活、logits 张量、梯度检查点输入,全部翻倍。

实测边界很清楚。 在同一台机器上,ORPO 单卡能跑 16K,到 24K 就 OOM(差 1.09 GiB)。

ORPO 单卡显存边界
图 1:ORPO partial-loss 单卡边界。16K 可跑(20.70 GiB),24K/32K 均 OOM,first-fail 在 MLP SwiGLU。

但真实数据需要长上下文。 我的场景是训练一个 PDB 调试 Agent,它的轨迹天然很长。9 条真实 debug 轨迹中,5 条超过 32K token,2 条超过 64K,最长的达到 83,720 token。长上下文训练不是在刷数字,而是被任务形态决定的。

真实 debug 轨迹长度分布
图 2:9 条真实 debug_subagent 轨迹的 prompt 长度。5/9 超过 32K,2/9 超过 64K。

走不通的四条路

在找到可行方案之前,我试了四条主流思路,全部失败。

方案原理失败原因
Ulysses SP=2/4沿序列维度切分 attentionORPO batch key 与 Ulysses adapter 不对齐,张量没被切分
Arctic ALSTCPU offload + tiled MLP16K 下两个变体均 OOM(21.07 vs 21.09 GiB),收益倒挂
Ring CP=2Context ParallelismFSDP2 + PEFTprepare() 阶段崩溃,集成太厚
logits_to_keep(错误版)只保留尾部 logitsinput_ids 也被截短,模型只看了 3K token

前三条是”方向对但接不上”。 Ulysses 的 SP 声称启用但张量没变,说明 ORPO 的 batch 构造和 SP adapter 期望的 key 对不上。ALST 在 sdpa 路径上反而更贵。Ring CP 在 FSDP2 + PEFT 的交互层就卡死了。

第四条最隐蔽。 这个 logits_to_keep patch 显存确实降了(从 OOM 降到 52% 占用),但后来发现它不只截了 lm_head 输出,还把 input_idsattention_mask 也裁短了。模型实际只看了尾部 3K token 做 attention,32K 上下文根本没进计算图。这不是显存优化,是悄悄换了一个更短的任务。这个教训后来反过来帮我理解了 logits 的正确用法。

80K 序列里只有 4K token 需要算 logits

这是整件事最关键的认知转变。

一条 Agent 轨迹的结构: 主 Agent 把 question(问题描述)和 test(失败的测试)传给 debug subagent,subagent 多轮交互产生一系列 action(工具调用)和 observation(环境返回)。一条典型轨迹大约 20 轮,最终 prompt 长度从 13K 到 84K 不等。

偏好训练的 score 只应该算在”模型的决策”上。 SimPO 的 avg_logp 表示 completion 部分的平均 log 概率。但”completion”在 Agent 轨迹里到底是什么?question 和 test 是任务输入,observation 是环境的确定性输出,这些都不是模型的决策。只有 action span 才代表模型真正需要学习的行为。

我的三次尝试:

第一版把 question + test 当 prompt,其余全部当 completion。一条 80K 的轨迹,question + test 可能只占 5K,那 “completion” 就有 75K token,每个都需要 lm_head 投影到 151936 维 vocab。单个 branch 的 logits 张量就要 75K × 151936 × 2B ≈ 22 GiB,直接 OOM。

第二版把 observation 也从 completion 里去掉,只保留 agent 的 action span。从实际轨迹统计看,9 条轨迹的 action token 总量只有 16,525,平均每次 API 调用约 86 token。一条轨迹的全部 action span 加起来大约 2K~5K token。logits 张量从 22 GiB 骤降到 4K × 151936 × 2B ≈ 1.2 GiB,直接可行。

第二版成功了。

这不是一个”显存优化技巧”。 logits_to_keep 本身是 transformers 4.45+ 的标准功能 [4],标准 SFT 和偏好训练里早就在用。真正的问题不是”用不用 logits_to_keep”,而是”在 Agent 轨迹里,completion 的边界在哪里”。

# 错误:75K completion → 22 GiB logits → OOM
labels = mark_everything_except_question_as_completion(input_ids)

# 正确:4K action span → 1.2 GiB logits → OK
labels = mark_only_action_spans_as_completion(input_ids)
# 其余位置 labels = -100,前向仍看全文,loss 只落在 action token 上

模型仍然对完整 80K token 做 attention。 logits_to_keep 不截断 input_ids,不改变 attention 计算。模型看到完整的 question、test、所有 observation、所有历史 action,但只在最后的 lm_head 阶段对 action span 的位置做 vocab 投影。上下文信息完整保留,logits 开销降到最低。

从 pair 结构找突破口

定义好 completion 边界后,logits 不再是瓶颈。但偏好训练还有第二个问题:一对序列的 MLP 激活和梯度检查点输入仍然翻倍。

回头看数据就会发现:单条 32K branch 本身是能训的(SFT 21.3 GiB),偏好训练爆显存是因为两条 branch 同时在一张卡上。瓶颈在 pair 结构,不在单条 branch。

为什么不用 Context Parallelism

Context Parallelism(CP)沿序列维度切分,看起来是长上下文的自然选择。但在 2 卡偏好训练这个场景下,它的收益和 branch-parallel 一样,工程代价却高得多。

先算显存。 CP=2 把每条 80K 序列切成两个 40K,但每张卡仍然要处理 chosen + rejected 两条序列。每卡 MLP 激活:2 branches × 40K = 80K token。Branch-parallel 每卡只处理一条完整 80K 序列。每卡 MLP 激活:1 branch × 80K = 80K token。两种方案在 2 卡上的 per-GPU 激活负载完全相同。

再看通信。 CP 需要每一层 attention 都做 KV 交换(ring attention),Qwen3-8B 有 36 层,每步要做 36 × 2 branches = 72 次跨卡通信。Branch-parallel 全程只交换 2 个标量 score + 一次 all-reduce LoRA 梯度。

最后看集成。 Unsloth 截至 2026 年初仍不支持 CP/SP,只提供 DDP [6]。CP 需要 FSDP2 或 DeepSpeed 与 PEFT 的深度集成,而我实测 Ring CP=2 在 FSDP2 + PEFTprepare() 阶段直接崩溃。即使 CP 的集成问题被修复,在 2 卡偏好训练的场景下也不会比 branch-parallel 更省显存。

维度Context Parallelism (CP=2)Branch-Parallel
每卡 MLP 激活2 branches × 40K = 80K token1 branch × 80K = 80K token
每步跨卡通信72 次 KV 交换(36 层 × 2 branches)2 个标量 + 1 次 LoRA grad all-reduce
与 unsloth/QLoRA 集成不支持,需 FSDP2 + PEFT直接可用,手动 gloo 同步
实现复杂度ring attention + 分布式框架 plumbing约 50 行 PyTorch 代码
CP vs Branch-Parallel 对比
图 3:CP=2 和 Branch-Parallel 在 2 卡 80K 场景下的对比。激活负载相同,通信和实现复杂度差距显著。

CP 在更多卡(4+)或更长序列(160K+)的场景下有价值,因为它可以和 branch-parallel 正交组合:4 卡 = 2 CP × 2 branch。但在 2 卡 80K 这个约束下,branch-parallel 是更简单、通信更轻、集成更薄的选择。

选择 SimPO 的原因

SimPO 的 loss 函数天然适合拆分。 它的目标是纯对称的 pairwise 偏好 [1]

L=logσ(β(aˉcaˉr)γ)L = -\log \sigma\bigl(\beta(\bar{a}_c - \bar{a}_r) - \gamma\bigr)

其中 aˉc\bar{a}_caˉr\bar{a}_r 分别是 chosen 和 rejected 的 completion-only 平均 log 概率。对参数 θ 求导,链式法则给出:

Lθ=Laˉcaˉcθ+Laˉraˉrθ\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial \bar{a}_c}\frac{\partial \bar{a}_c}{\partial \theta} + \frac{\partial L}{\partial \bar{a}_r}\frac{\partial \bar{a}_r}{\partial \theta}

两项分别只依赖各自 branch 的前向计算,可以在不同 GPU 上独立完成。DPO 需要额外的 reference model(4 次前向),ORPO 的 chosen 侧有额外 NLL 项(不对称),SimPO 是三者中拆分代价最低的。

Branch-Parallel 协议

核心只有 5 步:

# rank 0: chosen branch, rank 1: rejected branch
score = forward(branch_data)           # 各自前向,互不干扰
scores = all_gather(score)             # 交换两个标量
loss = simpo_loss(scores, rank)        # 各自用 detach 对方的 score 算 loss
loss.backward()                        # 各自反向,梯度只流过自己的 branch
all_reduce_lora_grads(model, SUM)      # 合并梯度 → optimizer.step()

不用 DDP。 unsloth 每卡独立加载模型,梯度同步用手动 dist.all_reduce(grad, op=SUM),没有框架自动 average 的缩放问题。两卡从同一初始参数出发,收到相同的合并梯度,每步更新后参数保持一致。

用 gloo 绕过 NCCL。 每个 worker 通过 CUDA_VISIBLE_DEVICES 只看到一块 GPU(cuda:0),但 NCCL 用 rank 作 device ordinal,rank 1 会尝试访问不存在的 cuda:1。gloo 走 CPU 通信,对于只交换标量 score + LoRA 梯度的场景,延迟可忽略。

等价性已验证。 短序列固定种子下,branch-parallel 与标准单卡 CPOTrainer 的 loss、梯度余弦相似度 0.9998,最大参数差异 1.5e-6。

爬坡结果:32K 到 80K

从 32K 开始逐步推高序列长度。基础配置统一为:Qwen3-8B + 4-bit QLoRA + unsloth + FA2 + gradient checkpointing + 双卡 branch-parallel + logits_to_keep(只对 action span 算 logits)。

序列长度实际 tokens步数峰值 GiB/卡余量 GiB秒/步状态
32K32,768113.619.9545Pass
64K65,5361019.014.55130Pass
72K73,728520.313.25159Pass
80K81,851121.052.51194Pass
90KOOMFail
Branch-Parallel 显存前沿
图 4:Branch-Parallel SimPO 显存前沿。64K 稳定通过 10 步,72K 通过 5 步,80K 通过 1 步。90K OOM。

显存增长亚线性。 32K 到 64K 序列翻倍,显存只增长 1.40x(13.61 → 19.01 GiB),说明 unsloth 的 gradient checkpointing 在长序列下效率很高。

80K 是当前实用上限。 81,851 token(77K prompt + 4K completion),峰值 21.05 GiB(占总显存 89.4%),吞吐 422 tokens/s/GPU。

72K 运行的 GPU 监控数据(2 秒采样,195 个样本点,5 步全程无异常):

指标GPU 0 (chosen)GPU 1 (rejected)
峰值显存22,870 MiB22,870 MiB
平均显存22,870 MiB22,870 MiB
平均 GPU 利用率99.0%98.5%
峰值 GPU 利用率100%100%
平均功耗339.5 W342.4 W
峰值温度67°C67°C

两卡显存全程钉在峰值,利用率接近 100%,不是 bursty 的假成功。

90K 的 OOM 是碎片问题。 可用碎片 2.59 GiB 大于需要申请的连续块 2.11 GiB,但 PyTorch 分配器找不到足够大的连续空间。first-fail 始终在 MLP SwiGLU(swiglu_fg_kernel)。

各层优化的贡献

80K 偏好训练的显存由六个部分组成。下表列出每个部分在有/无对应优化时的大小。

显存组成无优化时优化后用了什么
模型权重16.4 GiB (bf16)4.6 GiB4-bit QLoRA
优化器状态98.3 GiB (Adam bf16)0.3 GiBQLoRA(只更新 LoRA 参数)
Attention 中间结果28.8 TiB (O(n²) score 矩阵)≈0Flash Attention 2
MLP / 层间激活266 GiB30 GiBGradient Checkpointing (unsloth)
梯度张量16.4 GiB (bf16 全参)≈0(CPU)Gradient Offload (unsloth)
Logits 张量46.4 GiB (全量 80K × vocab × 2 branches)2.3 GiBlogits_to_keep(只对 action span 投影)
双 branch 叠加以上全部 ×2 / 单卡÷2 / 每卡Branch-Parallel 双卡

QLoRA、FA2、gradient checkpointing、gradient offload 是标准实践,不是本文的贡献。 任何做长上下文 LoRA 训练的人都会用这些。它们把 baseline 压到一个可工作的起点,但偏好训练还剩两个问题需要单独解决:

第一,logits 该算到哪里。 80K 序列的全量 logits 要 46.4 GiB(两个 branch),但只对 action span 算 logits 就只需要 2.3 GiB。差距来自对 completion 边界的定义:如果把 observation 也当 completion,显存直接爆炸;如果只算 action,轻松通过。

第二,pair 结构怎么拆。 偏好训练必须同时处理 chosen 和 rejected,单卡放不下就拆到两卡。SimPO 的对称 loss 函数让梯度可以精确拆分,两卡只需交换两个标量 score,然后 all-reduce LoRA 梯度。

小结

把 Agent 轨迹的偏好训练推到 80K(77K prompt + 4K action),核心靠两件事:搞清楚 80K 序列里只有 action span 需要算 logits(从 46 GiB 降到 2.3 GiB),再用 branch-parallel 把 pair 的两条 branch 拆到两张卡上。在 2×RTX 3090 上实测 21.05 GiB/卡,余量 2.51 GiB。下一步是用真实 PDB 轨迹做多步训练,验证 loss 收敛和 chosen/rejected margin 变化。



Previous Post
【六】长上下文 SFT 与双卡 BranchParallel + SimPO 代码实现
Next Post
【四】SFT 微调