Skip to content
传衡博客
返回

【六】长上下文 SFT 与双卡 BranchParallel + SimPO 代码实现

参考资料
  1. Unsloth — GitHub
  2. TRL — CPOTrainer 源码
  3. TRL — SFTTrainer 源码
  4. transformers — logits_to_keep (PR #33005)
  5. QLoRA: Efficient Finetuning of Quantized LLMs
  6. SimPO: Simple Preference Optimization with a Reference-Free Reward

这是 TraceForge 系列的第六篇,侧重代码实现。第四篇讲了四层显存优化的原理,第五篇讲了偏好训练特有的 completion 边界和双卡拆 pair。本篇把这两篇的思路落到可运行的代码上:先用 TRL 的标准 API 跑通 32K SFT,再从 CPOTrainer 出发手写 80K Branch-Parallel SimPO。

第一部分:32K SFT

模型加载:unsloth + 4-bit QLoRA

unsloth [1] 封装了模型加载和 LoRA 注入。核心是两个调用:

import torch
from unsloth import FastLanguageModel

# CUBLAS 半精度 bug 修复(torch 2.10 + CUDA 12.x + Ampere GPU)
torch.backends.cuda.preferred_blas_library("cublaslt")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="qwen3-8b-unsloth-bnb-4bit",  # 预量化权重
    max_seq_length=32768,
    load_in_4bit=True,                         # NF4 量化
    dtype=torch.bfloat16,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16, lora_alpha=16, lora_dropout=0,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",      # attention
        "gate_proj", "up_proj", "down_proj",           # MLP (SwiGLU)
    ],
    use_gradient_checkpointing="unsloth",  # unsloth 自定义的细粒度 GC
    random_state=42,
)scripts/finetune_qwen3_32k.py

load_in_4bit=True 调用 bitsandbytes 的 NF4 量化,把 8B 参数从 16.4 GiB(bf16)压到 4.6 GiB。use_gradient_checkpointing="unsloth" 启用 unsloth 在每个 transformer 层内部做的细粒度 checkpointing,比 PyTorch 标准的 torch.utils.checkpoint 更省显存。

target_modules 包含全部 7 个线性层。 Qwen3 的 MLP 用 SwiGLU 结构(gate_proj + up_proj → SiLU → down_proj),把这三层也加进 LoRA target 能覆盖更多参数路径。实测 r=16 时可训练参数 44M,占总参数 0.53%。

torch.backends.cuda.preferred_blas_library("cublaslt") 是一个必须加的 workaround。torch 2.10 + CUDA 12.8 在 Ampere GPU(RTX 3090, SM 8.6)上有个 cublasGemmEx 半精度 bug:torch.mm()torch.matmul() 在 bf16/fp16 下直接报 CUBLAS_STATUS_INVALID_VALUE,但 torch.addmm() 正常(走不同代码路径)。强制切换到 CUBLAS-LT API 可以绕过。这个 bug 不会给任何有意义的错误提示,只有一个泛化的 status code,排查了三天才定位。

数据格式化:Agent 轨迹 → 训练样本

真实的 debug subagent 轨迹存储为 JSON,每条消息有 rolecontent、可选的 reasoning_contentextra.actions。需要序列化为模型能消费的文本格式:

def serialize_history(messages, stop_idx):
    """将轨迹前 stop_idx 条消息序列化为上下文文本"""
    parts = []
    for idx, msg in enumerate(messages[:stop_idx]):
        role = msg.get("role", "")
        if role in {"system", "user"}:
            parts.append(f"[{role.upper()} #{idx}]\n{msg['content']}")
        elif role == "assistant":
            reasoning = (msg.get("reasoning_content") or "").strip()
            actions = msg.get("extra", {}).get("actions", [])
            payload = []
            if reasoning:
                payload.append("Reasoning:\n" + reasoning)
            for a in actions:
                payload.append(f"Action:\n{a['function']['name']}({a['function']['arguments']})")
            parts.append(f"[ASSISTANT #{idx}]\n" + "\n\n".join(payload))
    return "\n\n".join(parts)scripts/finetune_qwen3_32k_debug_traj.py

每个 assistant turn 生成一个训练样本: 把它之前的所有消息(system、user、历史 assistant + observation)序列化为 prompt,当前 assistant 的 action 作为 completion。然后用 tokenizer 的 chat template 组装:

prompt_messages = [
    {"role": "system", "content": "You are learning from a debug trajectory."},
    {"role": "user",   "content": f"Case: {case_id}\n{history}\nPredict next action."},
]
completion_messages = [
    {"role": "assistant", "content": target},
]

full_text = tokenizer.apply_chat_template(
    prompt_messages + completion_messages,
    tokenize=False,
    add_generation_prompt=False,
)
dataset.append({"text": full_text})scripts/finetune_qwen3_32k_debug_traj.py

SFT 训练:TRL SFTTrainer

数据准备好后,训练部分完全是标准 TRL [3]

from trl import SFTConfig, SFTTrainer

config = SFTConfig(
    output_dir=str(output_dir),
    max_seq_length=32768,
    per_device_train_batch_size=1,       # 单卡,batch=1
    gradient_accumulation_steps=4,        # 有效 batch=4
    max_steps=20,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=1,
    packing=False,     # 32K 长序列不 pack
    report_to="none",  # 离线训练,不上报 wandb
)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    args=config,
    formatting_func=lambda example: example["text"],
)

trainer.train()scripts/finetune_qwen3_32k_debug_traj.py

packing=False 是关键。 TRL 的 SFTTrainer 默认会把多个短样本拼接成一条长序列(packing),这对短文本效率很高,但对 32K 长序列会导致截断。关闭 packing 后每条样本独立处理。

formatting_func 直接返回预拼好的文本。TRL 会自动 tokenize 并构造 labels(整条序列都算 loss)。如果想只对 completion 部分算 loss,可以用 DataCollatorForCompletionOnlyLM,但 SFT baseline 阶段我们选择全序列训练。

第二部分:SimPO 偏好训练

第一步:用 CPOTrainer 建立单卡 oracle

在写自定义训练循环之前,先用 TRL 的 CPOTrainer [2] 跑一个标准 SimPO 作为对照组。后续所有自定义实现都要和这个 oracle 对齐。

from trl import CPOConfig, CPOTrainer

config = CPOConfig(
    output_dir=str(output_dir),
    loss_type="simpo",       # SimPO loss: L = -log σ(β(ā_c - ā_r) - γ)
    cpo_alpha=0.0,           # 纯 SimPO,不加 chosen-side NLL
    beta=2.0,                # 温度参数
    simpo_gamma=0.5,         # target reward margin
    max_length=512,
    max_prompt_length=384,
    per_device_train_batch_size=1,
    max_steps=3,
    learning_rate=1e-5,
    bf16=True,
    optim="adamw_8bit",
    disable_dropout=True,    # 偏好训练关 dropout
)

trainer = CPOTrainer(
    model=model,
    args=config,
    train_dataset=dataset,    # 含 prompt, chosen, rejected 字段
    processing_class=tokenizer,
)
result = trainer.train()experiments/branch_parallel_pref/simpo_oracle.py

偏好数据格式和 SFT 不同。CPOTrainer 期望每条样本有三个字段:prompt(消息列表)、chosen(消息列表)、rejected(消息列表):

def gen_preference_data(tokenizer, num_samples, max_seq_length):
    samples = []
    for i in range(num_samples):
        prompt = [
            {"role": "system", "content": "你是一个专业的文档分析助手。"},
            {"role": "user",   "content": f"请分析以下内容:\n{doc}"},
        ]
        chosen   = [{"role": "assistant", "content": "高质量回答..."}]
        rejected = [{"role": "assistant", "content": "低质量回答..."}]
        samples.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    return samplesexperiments/branch_parallel_pref/simpo_oracle.py

loss_type="simpo" + cpo_alpha=0.0 是纯 SimPO 的配置。CPOTrainer 默认 cpo_alpha=1.0 会额外加一个 chosen-side NLL 正则项,变成 CPO 而不是 SimPO。disable_dropout=True 确保 chosen 和 rejected 在同一 dropout mask 下比较。

第二步:Branch-Parallel 的启动方式

标准 DDP 用 torchrun 启动多个进程,但 unsloth 不支持 FSDP/DDP [1]。我们用 subprocess 手动启动两个 worker,每个 worker 通过 CUDA_VISIBLE_DEVICES 只看到一块 GPU:

def launch(args):
    gpu_ids = [g.strip() for g in args.gpus.split(",")]
    assert len(gpu_ids) == 2

    procs = []
    for rank, gpu_id in enumerate(gpu_ids):
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = gpu_id   # 每个 worker 只看到一块 GPU
        env["HF_HUB_OFFLINE"] = "1"

        cmd = [sys.executable, __file__,
               "--_worker", "--_rank", str(rank),
               "--model_path", args.model_path,
               "--port", str(args.port),
               # ... 其他参数 ...
               ]

        log_f = open(output_dir / f"worker_rank{rank}.log", "w")
        p = subprocess.Popen(cmd, env=env, stdout=log_f, stderr=subprocess.STDOUT)
        procs.append((p, log_f))

    # 等待两个 worker 完成
    for rank, (p, log_f) in enumerate(procs):
        ret = p.wait()
        log_f.close()
        if ret != 0:
            sys.exit(1)experiments/branch_parallel_pref/simpo_branch_parallel.py

为什么不用 DDP? 三个原因:(1)unsloth 截至 2026 年初只支持 DDP,不支持 FSDP 或模型并行;(2)NCCL 在 CUDA_VISIBLE_DEVICES 隔离下有 device ordinal 冲突(rank 1 尝试访问不存在的 cuda:1);(3)我们不需要 DDP 的数据并行语义,两张卡处理的是同一个 pair 的不同 branch,不是不同样本。

第三步:Worker 初始化

每个 worker 做三件事:加载模型、初始化分布式通信、同步初始参数。

def worker(args):
    import torch
    import torch.distributed as dist
    from unsloth import FastLanguageModel

    rank = args._rank
    torch.cuda.set_device(0)  # 每个 worker 只看到一块 GPU

    # 1) 加载模型(必须在 dist.init 之前)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_path,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True, dtype=torch.bfloat16,
    )
    model = FastLanguageModel.get_peft_model(model, r=16, ...)
    FastLanguageModel.for_training(model)

    # 2) 初始化 gloo 通信(不用 NCCL)
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{args.port}",
        rank=rank, world_size=2,
    )

    # 3) 广播 rank 0 的参数到 rank 1(确保起点一致)
    for name, param in model.named_parameters():
        if param.requires_grad:
            dist.broadcast(param.data, src=0)experiments/branch_parallel_pref/simpo_branch_parallel.py

模型加载必须在 dist.init_process_group 之前。 如果顺序反过来,dist.init 会影响 transformers 内部的 device context,导致 unsloth 加载模型时出错。

backend="gloo" 走 CPU 通信。对于只交换两个标量 score + LoRA 梯度的场景,CPU 延迟可忽略。gloo 不需要 NCCL 的 device ordinal 映射,绕过了 CUDA_VISIBLE_DEVICES 的冲突。

第四步:前向传播 + logits_to_keep

这是长上下文偏好训练最关键的优化。模型对完整序列做 attention,但只对 completion 部分的 token 计算 logits(vocab 投影):

# rank 0 处理 chosen,rank 1 处理 rejected
input_ids = branch_input_ids   # (1, 80K)
labels = branch_labels          # prompt 位置 = -100,completion 位置 = token id

# 计算 completion 有多少 token
n_completion = int((labels != -100).sum().item())  # ~4K
n_keep = n_completion + 2  # 多保留 2 个,供自回归 shift 对齐

# 前向:模型看完整 80K,但 lm_head 只投影最后 n_keep 个位置
outputs = model(
    input_ids=input_ids,
    use_cache=False,
    num_logits_to_keep=n_keep,  # transformers 4.45+ 原生支持
)
logits = outputs.logits  # shape: (1, ~4K, 151936),不是 (1, 80K, 151936)experiments/branch_parallel_pref/simpo_branch_parallel.py

num_logits_to_keep 是 transformers 4.45+(2024.08)加入的参数 [4]。它只作用于 lm_head(最后的 vocab 投影层),不改变模型内部的 attention 和 MLP 计算。80K 序列下,全量 logits 要 80K × 151936 × 2B = 23 GiB;只保留 4K completion 部分则只需 4K × 151936 × 2B = 1.2 GiB

有了 logits 后,用 TRL 的 get_batch_logps 计算 completion-only 的平均 log 概率:

from trl.trainer.cpo_trainer import CPOTrainer

labels_tail = labels[:, -n_keep:]  # 只取最后 n_keep 个位置的 label

avg_logp = CPOTrainer.get_batch_logps(
    logits,
    labels_tail,
    average_log_prob=True,     # SimPO 用平均 log prob
    label_pad_token_id=-100,
    is_encoder_decoder=False,
)
local_score = avg_logp.squeeze()  # 标量,带梯度experiments/branch_parallel_pref/simpo_branch_parallel.py

CPOTrainer.get_batch_logps 是 TRL 的静态方法 [2],内部做自回归 shift(logits[:, :-1] 对 labels[:, 1:])、selective_log_softmax、然后按 labels != -100 的 mask 求平均。我们直接复用它,不重写。

第五步:交换 score + 计算 loss

两个 worker 各自完成前向后,通过 gloo 交换标量 score,然后各自计算 SimPO loss [6]

# 交换 score(只传两个浮点数)
score_cpu = torch.tensor([local_score.item()])  # 断开计算图
all_scores = [torch.zeros(1) for _ in range(2)]
dist.all_gather(all_scores, score_cpu)

score_chosen  = all_scores[0].item()
score_rejected = all_scores[1].item()

# 计算 SimPO loss:L = -log σ(β(ā_c - ā_r) - γ)
import torch.nn.functional as F

if rank == 0:
    # rank 0 对 chosen 有梯度,rejected score 是常数
    loss = -F.logsigmoid(
        args.beta * (local_score - score_rejected) - args.simpo_gamma
    )
else:
    # rank 1 对 rejected 有梯度,chosen score 是常数
    loss = -F.logsigmoid(
        args.beta * (score_chosen - local_score) - args.simpo_gamma
    )experiments/branch_parallel_pref/simpo_branch_parallel.py

local_score.item() 断开计算图。 传给对方的是纯浮点数,没有梯度。但 local_score 本身(GPU 上的 tensor)仍然连着整个前向计算图。所以 rank 0 的 loss.backward() 只回传 chosen branch 的梯度,rank 1 只回传 rejected branch 的梯度。两边加起来等价于单卡同时算两条 branch 的总梯度。

第六步:梯度同步 + 参数更新

# 各自反向传播
loss.backward()

# AllReduce(SUM) LoRA 梯度
for param in model.parameters():
    if param.requires_grad and param.grad is not None:
        cpu_grad = param.grad.detach().float().cpu()
        dist.all_reduce(cpu_grad, op=dist.ReduceOp.SUM)
        param.grad = cpu_grad.to(param.device).to(param.grad.dtype)

# 统一更新(两个 worker 收到相同梯度,更新后参数保持一致)
optimizer.step()
optimizer.zero_grad()experiments/branch_parallel_pref/simpo_branch_parallel.py

CPU 中转是 gloo 的限制。 gloo 不能直接操作 GPU tensor,需要先搬到 CPU 做 all_reduce 再搬回。对于 LoRA 参数(rank=16,每层约 12M 参数),这个开销可以忽略。如果是全参数训练(8B),CPU 中转会成为瓶颈。

ReduceOp.SUM 而不是 MEAN 因为每个 rank 的梯度只包含一条 branch 的贡献,SUM 后等于完整的 SimPO 梯度。如果用 MEAN 会导致梯度幅度减半。

第三部分:GPU 监控

训练过程中用后台线程每 2 秒采一次 nvidia-smi,记录显存、利用率、功耗和温度:

def build_monitor_thread(output_dir, gpu_ids, interval_s=2):
    stop_event = threading.Event()
    csv_path = output_dir / "nvidia_smi_monitor.csv"

    def monitor():
        with csv_path.open("w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["ts", "elapsed_s", "gpu_index",
                             "memory_used_mib", "memory_total_mib",
                             "util_gpu", "util_mem", "power_draw_w", "temp_c"])
            t0 = time.time()
            while not stop_event.is_set():
                proc = subprocess.run(
                    ["nvidia-smi",
                     "--query-gpu=index,memory.used,memory.total,"
                     "utilization.gpu,utilization.memory,power.draw,temperature.gpu",
                     "--format=csv,noheader,nounits"],
                    capture_output=True, text=True,
                )
                for line in proc.stdout.strip().splitlines():
                    parts = [x.strip() for x in line.split(",")]
                    if parts[0] in gpu_ids:
                        writer.writerow([time.strftime("%H:%M:%S"),
                                         round(time.time() - t0, 2)] + parts)
                f.flush()
                stop_event.wait(interval_s)

    thread = threading.Thread(target=monitor, daemon=True)
    return stop_event, threadscripts/finetune_qwen3_32k_debug_traj.py

训练开始前 thread.start(),训练结束后 stop_event.set()。输出的 CSV 可以直接用 pandas 做后续分析(峰值显存、利用率曲线等)。

标准 TRL 代码 vs 自定义代码

环节标准 TRL/UnslothBranch-Parallel 自定义
模型加载FastLanguageModel.from_pretrained()相同,但必须在 dist.init 之前
LoRA 配置FastLanguageModel.get_peft_model()相同
SFT 训练SFTTrainer不涉及
偏好训练CPOTrainer(loss_type="simpo")手写 forward/backward + allreduce
avg_logp 计算CPOTrainer 内部调用复用 CPOTrainer.get_batch_logps
logits 优化CPOTrainer 不原生支持手动传 num_logits_to_keep
多卡通信DDP(unsloth 仅支持这一种)subprocess + gloo 手动同步
梯度同步DDP 自动 all-reduce手动 dist.all_reduce(SUM)

自定义的部分只有三处: logits_to_keep 的正确使用、gloo score 交换、手动梯度 all-reduce。其余全部复用 unsloth 和 TRL 的标准组件。

小结

32K SFT 用标准 TRL SFTTrainer 即可,关键是 packing=False 和正确的 chat template 格式化。80K SimPO 需要跳出 CPOTrainer,但核心公式(get_batch_logps、SimPO loss)仍然复用 TRL 源码。真正需要手写的只有 logits_to_keep 传参、两个标量的 gloo 交换、以及 LoRA 梯度的 all-reduce。完整代码见 GitHub 仓库



Previous Post
【七】基础RL:从策略梯度定理到 PPO 算法
Next Post
【五】SimPO 训练与 BranchParallel 策略实现