Skip to content
传衡博客
返回

【八】离线RL:DPO 与 SimPO 的推导与代码实现

参考资料
  1. Training language models to follow instructions with human feedback
  2. Direct Preference Optimization: Your Language Model is Secretly a Reward Model
  3. SimPO: Simple Preference Optimization with a Reference-Free Reward

DPO 是目前工业界最主流的离线偏好优化方法,把 RLHF 的四模型在线 RL 循环压缩成一个二分类 loss,工程上几乎零额外开销。SimPO 再往前一步,连 reference model 也砍掉,训练速度提升 20%、峰值显存降低 10%[3]

效果上,DPO 在 TL;DR 摘要任务上达到 61% GPT-4 胜率,超过 PPO 的 57%[2]。SimPO 在 AlpacaEval 2 和 Arena-Hard 上进一步把胜率分别拉高 6.4%7.5%[3]

这篇文章从 RLHF 目标出发,推导 DPO 和 SimPO 的公式,解释”为什么能消掉 reward model”,并用代码走一遍 SimPO 的完整计算流程。

回顾:RLHF 目标与 PPO 实现

RLHF 的总体目标

RLHF 的目标是训练一个 policy,让它既能拿到高的 reward,又不要偏离 reference 太远:

maxπExD,yπ(yx)[Rϕ(x,y)]βExD[DKL(π(yx)πref(yx))]\max_{\pi} \, \mathbb{E}_{x \sim D, y \sim \pi(y|x)} \left[ R_\phi(x, y) \right] - \beta \mathbb{E}_{x \sim D} \left[ D_{\text{KL}}(\pi(y|x) \, \| \, \pi_{\text{ref}}(y|x)) \right]

其中:

这是 RLHF 的”问题定义”:我们想要达成的目标是这样的。

PPO 的具体实现

PPO 是用来实现上述目标的一种方法。PPO 的完整 objective 由 三项 组合成一个总 loss:

Ltotal=Lpolicy+c1Lvaluec2Lentropy\mathcal{L}^{\text{total}} = \mathcal{L}^{\text{policy}} + c_1 \mathcal{L}^{\text{value}} - c_2 \mathcal{L}^{\text{entropy}}

三项各自的作用

组件公式作用
Policy lossLpolicy=Et[min(rtA^t,clip(rt)A^t)]\mathcal{L}^{\text{policy}} = -\mathbb{E}_t[\min(r_t \hat{A}_t, \text{clip}(r_t)\hat{A}_t)]更新策略,使高 advantage 动作的概率提升。前面带负号是因为要用梯度下降最小化 loss
Value lossLvalue=(Vθ(st)Vttarg)2\mathcal{L}^{\text{value}} = (V_\theta(s_t) - V_t^{\text{targ}})^2训练 critic 准确预测 value,让 advantage 估计更准
Entropy bonusLentropy=Et[aπθ(ast)logπθ(ast)]\mathcal{L}^{\text{entropy}} = -\mathbb{E}_t[\sum_a \pi_\theta(a \mid s_t) \log \pi_\theta(a \mid s_t)]鼓励探索,防止策略过早收敛到局部最优

其中 rt=πθ(atst)πθold(atst)r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} 是重要性采样比率,c10.5c_1 \approx 0.5c20.01c_2 \approx 0.01 是权重系数。

关键点:PPO 只有 一个总 loss, backward 时同时产生 policy、value、entropy 三个梯度,一次更新 policy network 和 critic network(如果它们共享 backbone,梯度会累加到共享层)。

PPO 完整训练流程

  1. Rollout:用当前策略采样一批轨迹
  2. 计算 advantage(用 GAE)
  3. 多次更新:同一批数据重复算总 loss,反复更新参数
  4. 重复

PPO 存在的问题

PPO 的问题:需要 4 个模型(Policy, Reference, RM, Critic)、在线采样、训练不稳定。

DPO 的思路:既然 RLHF 的目标已经很明确,我们能不能从 RLHF 目标出发,通过数学推导直接得到一个新的 loss?这样就不需要 PPO 那套复杂的在线 RL 机制了。

关键洞察:如果我们知道最优 policy,就可以反推出 reward 的表达式,从而消掉 reward model。

DPO 推导:四步闭式消元

下面的推导过程不需要记,真正需要记住的是:

  • 核心思想:RLHF 目标可以推导成 policy log ratio 的二分类 loss
  • 最终公式:DPO loss 的样子
  • 直觉理解:DPO 是怎么工作的

DPO 的核心发现:在 Bradley-Terry 偏好模型假设下,可以从 KL 约束的 RL 目标中解析解出最优 policy,然后反推出 reward 的表达式,最终消掉 reward model。

Step 1:写出 KL 约束优化的拉格朗日形式

对于固定 prompt xx,优化目标是:

maxπEyπ[R(x,y)]βDKL(ππref)\max_{\pi} \, \mathbb{E}_{y \sim \pi}\left[ R(x, y) \right] - \beta \, D_{\text{KL}}(\pi \| \pi_{\text{ref}})

展开 KL 散度:

=maxππ(yx)R(x,y)dyβπ(yx)logπ(yx)πref(yx)dy= \max_{\pi} \int \pi(y|x) R(x, y) \, dy - \beta \int \pi(y|x) \log \frac{\pi(y|x)}{\pi_{\text{ref}}(y|x)} \, dy

Step 2:求解闭式最优策略

这是一个带约束的变分优化问题。用拉格朗日乘子法,约束是 π(yx)dy=1\int \pi(y|x) \, dy = 1(概率归一化)。

构造拉格朗日函数:

L(π,λ)=π(yx)[R(x,y)βlogπ(yx)πref(yx)]dyλ(π(yx)dy1)\mathcal{L}(\pi, \lambda) = \int \pi(y|x) \left[ R(x, y) - \beta \log \frac{\pi(y|x)}{\pi_{\text{ref}}(y|x)} \right] dy - \lambda \left( \int \pi(y|x) dy - 1 \right)

π(yx)\pi(y|x) 求变分导数并令其为零:

R(x,y)βlogπ(yx)πref(yx)βλ=0R(x, y) - \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} - \beta - \lambda = 0

整理得:

logπ(yx)πref(yx)=R(x,y)ββ+λβ\log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} = \frac{R(x, y)}{\beta} - \frac{\beta + \lambda}{\beta}

指数化:

π(yx)=1Z(x)πref(yx)exp(1βR(x,y))\pi^*(y|x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y|x) \exp\left( \frac{1}{\beta} R(x, y) \right)

其中 Z(x)=exp(β+λβ)Z(x) = \exp\left( \frac{\beta + \lambda}{\beta} \right) 是归一化常数(partition function),确保 π(yx)dy=1\int \pi^*(y|x) dy = 1

直观理解:最优策略在 reference 的基础上,按 exponentiated reward 重新加权。reward 越高的回答,概率提升越多。

Step 3:反解 reward

从 Step 2 的结果反解 R(x,y)R(x, y)

R(x,y)=βlogπ(yx)πref(yx)+βlogZ(x)R(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x)

关键洞察:最优 reward 可以用最优 policy 和 reference policy 的 log ratio 表示!

这意味着:如果我们知道最优 policy,就不需要显式的 reward model。

Step 4:代入 Bradley-Terry 偏好模型得到 DPO Loss

人类偏好数据通常是成对比较:对于同一个 prompt xx,比较两个回答 ywy_w(win)和 yly_l(lose)。

Bradley-Terry 模型假设偏好概率为:

P(ywylx)=σ(R(x,yw)R(x,yl))P(y_w \succ y_l | x) = \sigma(R(x, y_w) - R(x, y_l))

其中 σ(z)=11+ez\sigma(z) = \frac{1}{1 + e^{-z}} 是 sigmoid 函数。

把 Step 3 的 reward 表达式代入:

R(x,yw)R(x,yl)=βlogπ(ywx)πref(ywx)βlogπ(ylx)πref(ylx)R(x, y_w) - R(x, y_l) = \beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi^*(y_l|x)}{\pi_{\text{ref}}(y_l|x)}

注意 βlogZ(x)\beta \log Z(x) 被消掉了!这是 DPO 能work的关键——不需要知道归一化常数。

最终得到 DPO loss:

LDPO(πθ;πref)=E(x,yw,yl)D[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))]\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right]

直觉理解

公式 说明:

  1. 不需要 reward model:reward 被隐式地表示为 policy 和 reference 的 log ratio
  2. 不需要在线采样:用离线偏好数据 (x,yw,yl)(x, y_w, y_l) 直接训练
  3. 不需要 PPO:简单的二分类 loss

DPO 的隐式 reward 与梯度

定义 DPO 的 隐式 reward

rθ(x,y)=βlogπθ(yx)πref(yx)r_\theta(x, y) = \beta \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}

对 DPO loss θ\theta 求导,看梯度更新方向:

θLDPO=E[σ(Δr)β(θlogπθ(ywx)θlogπθ(ylx))]\nabla_\theta \mathcal{L}_{\text{DPO}} = -\mathbb{E} \left[ \sigma(-\Delta r) \cdot \beta \left( \nabla_\theta \log \pi_\theta(y_w|x) - \nabla_\theta \log \pi_\theta(y_l|x) \right) \right]

其中 Δr=rθ(x,yw)rθ(x,yl)\Delta r = r_\theta(x, y_w) - r_\theta(x, y_l) 叫做 margin(差距/裕量),表示 chosen 和 rejected 的 reward 差值。

为什么叫 margin:就像 SVM 里的间隔一样,margin 越大,模型越确信 chosen 比 rejected 好。当 margin > 0 时,模型判断正确;margin 越大,loss 越小,梯度越小(已经学好了)。

梯度效果

数值示例:从模型输出到 Loss

πθ(yx)\pi_\theta(y|x) 是模型对序列 yy 的条件概率。对于自回归模型,它等于每个 token 条件概率的连乘:

πθ(yx)=t=1yπθ(ytx,y<t)\pi_\theta(y|x) = \prod_{t=1}^{|y|} \pi_\theta(y_t | x, y_{<t})

取 log 后变成求和,这就是 logπθ(yx)=tlogπθ(ytx,y<t)\log \pi_\theta(y|x) = \sum_{t} \log \pi_\theta(y_t | x, y_{<t}) 的来源。

具体计算流程

假设 β=0.1\beta = 0.1,prompt 是”请解释量子计算”,模型前向传播后:

步骤操作Chosen (y_w)Rejected (y_l)
1模型输出 logitsshape [32, 151936],每个位置是 vocab 上的分数shape [28, 151936]
2Log softmax 得 log prob每个 token 的 log prob 在 [-10, 0] 之间类似
3关键 token 示例”量子”位置的 log prob = -0.95
”比特”位置的 log prob = -0.72
”量子”位置的 log prob = -1.50
”原理”位置的 log prob = -0.85
4求和得 logπθ\sum \log \pi_\theta-12.8 (32个token)-18.5 (28个token)
5平均得 Avglogπθ\text{Avg} \log \pi_\theta-0.40-0.66
6Reference 模型同样计算Avglogπref\text{Avg} \log \pi_{\text{ref}} = -0.44Avglogπref\text{Avg} \log \pi_{\text{ref}} = -0.64
7Log ratio+0.04-0.02

为什么选择”量子”这个 token 很重要

模型对 chosen 中的”量子”预测更自信(-0.95 vs -1.50),因为 chosen 的回答”量子计算利用量子比特…”更符合训练数据的分布。这体现在:

从 log ratio 到 loss

计算项ChosenRejected说明
Log ratio+0.04-0.02policy - reference
隐式 reward (β×\beta \times ratio)+0.004-0.002乘以 0.1
MarginΔr=0.004(0.002)=\Delta r = 0.004 - (-0.002) =0.006差距很小,模型还没学好
Sigmoidσ(0.006)\sigma(0.006) \approx0.5015接近 0.5,不太确定
Losslog(0.5015)-\log(0.5015) \approx0.689还有优化空间

直观感受这些数字

本轮训练前后对比(模型已经过若干轮训练,不是初始 SFT 状态):

本轮训练前 Chosen本轮训练前 Rejected本轮训练后 Chosen本轮训练后 Rejected
Avglogπθ\text{Avg} \log \pi_\theta-0.40-0.66-0.35-0.70
Avglogπref\text{Avg} \log \pi_{\text{ref}}-0.44-0.64-0.44-0.64(frozen)
logπθπref\log \frac{\pi_\theta}{\pi_{\text{ref}}}+0.04-0.02+0.09-0.06
隐式 reward+0.004-0.002+0.009-0.006

本轮训练的效果:πθ(ywx)\pi_\theta(y_w|x) 从 -0.40 提升到 -0.35(概率增大),πθ(ylx)\pi_\theta(y_l|x) 从 -0.66 降低到 -0.70(概率减小),margin 从 0.006 变成 0.015,loss 从 0.689 降低到 0.556。

初始状态(Step 0):如果是刚初始化的 SFT 模型(πθ=πref\pi_\theta = \pi_{\text{ref}}),则 chosen 和 rejected 的 log ratio 都是 0,margin = 0,loss = -log σ(0) = 0.693。表格展示的是训练过程中某一轮的改进。

DPO 训练过程拆解

DPO 的计算流程比 PPO 简单很多,不需要在线采样、不需要 critic、不需要 rollout。假设我们有一对偏好数据,用具体数值演示整个计算循环。

输入示例

流程概览

Step 1: Tokenize —— 分词与对齐

# Tokenize prompt, chosen, rejected
prompt_tokens = tokenizer("请解释量子计算")["input_ids"]  # [8]
chosen_tokens = tokenizer("请解释量子计算" + chosen)["input_ids"]  # [40] = 8 + 32
rejected_tokens = tokenizer("请解释量子计算" + rejected)["input_ids"]  # [36] = 8 + 28

prompt_len = len(prompt_tokens)  # 8
chosen_len = len(chosen_tokens)  # 40
rejected_len = len(rejected_tokens)  # 36

# Batch tensor for forward pass (with padding)
input_ids = tokenizer(
    [chosen, rejected],
    padding=True,
    return_tensors="pt"
)["input_ids"]  # [2, 40],rejected 会被 padding 到 40

Shape 解释

关键位置的 token

序列位置token中文说明
chosen03891prompt 开头
chosen89123量子chosen_response 开头
chosen39151643序列结束符
rejected89123量子rejected_response 开头
rejected35151643序列结束符

Step 2: Policy 前向传播与 Log Probs 提取

# Policy 前向传播
outputs = policy(input_ids)
logits = outputs.logits[:, :-1, :]  # [2, 39, 151936],预测下一个 token

# Log softmax
log_probs = F.log_softmax(logits, dim=-1)  # [2, 39, 151936]

# 提取实际 token 对应的 log prob
labels = input_ids[:, 1:]  # [2, 39],目标 token
policy_logp = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [2, 39]

# 创建 completion mask(只计算 response 部分,不计算 prompt)
completion_mask = torch.zeros_like(policy_logp, dtype=torch.bool)
completion_mask[:, prompt_len-1:] = True  # [2, 39]

# Masked log probs(只保留 response 部分)
policy_logp_masked = policy_logp * completion_mask  # [2, 39]

Shape 解释

关键位置的数值(chosen 响应部分):

位置token ID中文对应概率policy_logp含义
715234计算0.22-1.50模型有 22% 置信度预测”计算”是下一个 token
89123量子0.39-0.95模型有 39% 置信度预测”量子”是下一个 token
154532比特0.49-0.72模型有 49% 置信度预测”比特”是下一个 token
209876经典0.56-0.58模型有 56% 置信度预测”经典”是下一个 token
381516430.26-1.35模型有 26% 置信度预测结束

logp 的意义:对数概率。logp 越接近 0,模型越”确信”这个 token 是正确的。指数后得到概率:exp(-0.72) ≈ 0.49

Step 3: 计算 Chosen 和 Rejected 的 Response Log Prob

# Sum over sequence dimension(只计算 response 部分)
policy_sum_logp = policy_logp_masked.sum(dim=-1)  # [2]

# Count completion tokens
completion_lengths = completion_mask.sum(dim=-1)  # [2]
# chosen_lengths = 32, rejected_lengths = 28

# Split chosen and rejected
chosen_sum_logp = policy_sum_logp[0]  # scalar, 比如 -12.8
rejected_sum_logp = policy_sum_logp[1]  # scalar, 比如 -18.5

chosen_len = completion_lengths[0]  # 32
rejected_len = completion_lengths[1]  # 28

# Average log prob(可选,DPO 可以用 sum 或 average)
chosen_avg_logp = chosen_sum_logp / chosen_len  # -0.40
rejected_avg_logp = rejected_sum_logp / rejected_len  # -0.66

关键位置的数值

Sum Log Prob长度Avg Log Prob含义
Chosen (y_w)-12.832-0.40chosen 平均每个 token 的 log prob
Rejected (y_l)-18.528-0.66rejected 平均每个 token 的 log prob

直观理解

Step 4: Reference 前向传播

# Reference model(frozen,不训练)
with torch.no_grad():
    ref_outputs = reference(input_ids)
    ref_logits = ref_outputs.logits[:, :-1, :]  # [2, 39, 151936]
    ref_log_probs = F.log_softmax(ref_logits, dim=-1)  # [2, 39, 151936]
    ref_logp = ref_log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [2, 39]
    ref_logp_masked = ref_logp * completion_mask  # [2, 39]

# Sum over sequence dimension
ref_sum_logp = ref_logp_masked.sum(dim=-1)  # [2]

# Split chosen and rejected
chosen_ref_sum_logp = ref_sum_logp[0]  # scalar, 比如 -14.2
rejected_ref_sum_logp = ref_sum_logp[1]  # scalar, 比如 -17.8

chosen_ref_avg_logp = chosen_ref_sum_logp / chosen_len  # -0.44
rejected_ref_avg_logp = rejected_ref_sum_logp / rejected_len  # -0.64

关键位置的数值(参考模型的结果):

Ref Sum Log ProbRef Avg Log Prob含义
Chosen (y_w)-14.2-0.44reference 对 chosen 的平均 log prob
Rejected (y_l)-17.8-0.64reference 对 rejected 的平均 log prob

为什么 reference 的 log prob 比当前 policy 低

Step 5: 计算 DPO 隐式 Reward —— Log Ratio

回顾 DPO 的隐式 reward 公式:

rθ(x,y)=βlogπθ(yx)πref(yx)r_\theta(x, y) = \beta \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}

这里的 logπθπref\log \frac{\pi_\theta}{\pi_{\text{ref}}} 就是 policy 和 reference 的 log prob 差值。

beta = 0.1  # KL 约束系数

# Log ratio(DPO 的隐式 reward)
# 注意:DPO 可以用 sum 或 average,这里用 average
chosen_log_ratio = chosen_avg_logp - chosen_ref_avg_logp  # -0.40 - (-0.44) = 0.04
rejected_log_ratio = rejected_avg_logp - rejected_ref_avg_logp  # -0.66 - (-0.64) = -0.02

# 隐式 reward(乘以 beta)
chosen_reward = beta * chosen_log_ratio  # 0.1 * 0.04 = 0.004
rejected_reward = beta * rejected_log_ratio  # 0.1 * (-0.02) = -0.002

关键位置的数值(β=0.1):

Policy Avg LogpRef Avg LogpLog Ratio (policy - ref)隐式 Reward (β × ratio)含义
Chosen (y_w)-0.40-0.44+0.04+0.004policy 比 ref 更”喜欢”chosen
Rejected (y_l)-0.66-0.64-0.02-0.002policy 比 ref 更”不喜欢”rejected

Log Ratio 的意义

Margin 计算

Step 6: DPO Loss 计算 —— 二分类 Loss

回顾 DPO loss 公式:

LDPO=E[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))]\mathcal{L}_{\text{DPO}} = -\mathbb{E} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right]
# Margin(隐式 reward 的差值)
margin = beta * (chosen_log_ratio - rejected_log_ratio)  # 0.006

# DPO loss(负的 log sigmoid)
loss = -F.logsigmoid(margin)  # scalar

# Backward
loss.backward()

数值计算

梯度更新方向

训练前后对比(假设训练一轮后):

训练前 Chosen训练前 Rejected训练后 Chosen训练后 Rejected改变
Policy Avg Logp-0.40-0.66-0.35-0.70chosen 提升,rejected 降低
Log Ratio+0.04-0.02+0.09-0.06margin 变大
隐式 Reward+0.004-0.002+0.009-0.006margin 变大

为什么训练后 margin 变大

SimPO:砍掉 Reference Model

DPO 虽然已经省掉了 reward model 和在线 RL,但仍然需要 reference model。每个 batch 要多做一遍前向传播,显存和速度都有开销。

SimPO 发现 DPO 有两个问题[3]

  1. Reward 与生成指标不一致:DPO 优化的是 logπθπref\log \frac{\pi_\theta}{\pi_{\text{ref}}}(log ratio),但实际生成时只关心 πθ\pi_\theta 自己的概率
  2. Reference model 开销:每个 sample 多做一遍前向

SimPO 的修改

修改 1:隐式 reward 改为 length-normalized average log prob

rSimPO(x,y)=βyt=1ylogπθ(ytx,y<t)r_{\text{SimPO}}(x, y) = \frac{\beta}{|y|} \sum_{t=1}^{|y|} \log \pi_\theta(y_t | x, y_{<t})

和 DPO 的对比:

除以长度 y|y| 是为了归一化,防止长回答天然概率低的问题。

修改 2:加入 target margin γ\gamma

LSimPO(πθ)=E(x,yw,yl)D[logσ(βywlogπθ(ywx)βyllogπθ(ylx)γ)]\mathcal{L}_{\text{SimPO}}(\pi_\theta) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\left( \frac{\beta}{|y_w|} \log \pi_\theta(y_w|x) - \frac{\beta}{|y_l|} \log \pi_\theta(y_l|x) - \gamma \right) \right]

γ\gamma 是超参数(通常取 0.5-1.0),表示期望 chosen 和 rejected 之间的最小 margin。

DPO vs SimPO 对比

特性DPOSimPO
隐式 rewardβlogπθπref\beta \log \frac{\pi_\theta}{\pi_{\text{ref}}}$\frac{\beta}{
需要 reference✅ 需要❌ 不需要
长度归一化隐式(通过 ratio)显式(除以 |y|)
Target margin有(γ\gamma
显存开销2× policy1× policy
训练速度慢(2x forward)快(1x forward)

SimPO 在 AlpacaEval 2 上达到 72.4% 长度控制胜率,Arena-Hard 上 59.1%[3]。工程上,训练时间少 20%,峰值显存少 10%[3]

偏好数据构造

SFT 数据:(prompt, answer)。DPO/SimPO 需要:(prompt, chosen, rejected)

数据来源

人工标注(InstructGPT[1]):

自动化构造(SimPO[3]):

从线上 rollout 回流

什么样的样本适合做 SimPO

适合不适合
有明确好坏之分的(正确 vs 错误)都好或都差(难以区分)
格式规范可比较的格式混乱无法解析
同一 prompt 有多个候选只有一个回答

如果 ywy_wyly_l 差距不明显,模型学到的信号弱,训练效率低。

SimPO 训练过程拆解

SimPO 不需要 reference model,所以比 DPO 更简单。继续用上面的例子,演示 SimPO 的计算流程。

输入示例(与 DPO 相同):

流程概览

Step 1: Tokenize —— 分词与对齐

# Tokenize prompt, chosen, rejected
prompt = "请解释量子计算"
chosen = "量子计算利用量子比特进行计算,可以同时处理多个状态,计算能力远超经典计算机。"
rejected = "量子计算就是用量子做的计算,非常复杂,涉及很多物理原理。"

prompt_tokens = tokenizer(prompt)["input_ids"]  # [8]
chosen_tokens = tokenizer(prompt + chosen)["input_ids"]  # [40] = 8 + 32
rejected_tokens = tokenizer(prompt + rejected)["input_ids"]  # [36] = 8 + 28

prompt_len = len(prompt_tokens)  # 8
chosen_len = len(chosen_tokens)  # 40
rejected_len = len(rejected_tokens)  # 36

# Batch tensor for forward pass (with padding)
input_ids = tokenizer(
    [chosen, rejected],
    padding=True,
    return_tensors="pt"
)["input_ids"]  # [2, 40],rejected 会被 padding 到 40

Shape 解释

Step 2: Policy 前向传播与 Log Probs 提取

# Policy 前向传播(SimPO 只需要 policy,不需要 reference)
outputs = policy(input_ids)
logits = outputs.logits[:, :-1, :]  # [2, 39, 151936],预测下一个 token

# Log softmax
log_probs = F.log_softmax(logits, dim=-1)  # [2, 39, 151936]

# 提取实际 token 对应的 log prob
labels = input_ids[:, 1:]  # [2, 39],目标 token
token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [2, 39]

# 创建 completion mask(只计算 response 部分,不计算 prompt)
completion_mask = torch.zeros_like(token_log_probs, dtype=torch.bool)
completion_mask[:, prompt_len-1:] = True  # [2, 39]

# Masked log probs(只保留 response 部分)
masked_log_probs = token_log_probs * completion_mask  # [2, 39]

Shape 解释

关键位置的数值(chosen 响应部分):

位置token ID中文对应概率token_log_prob含义
715234计算0.33-1.10模型有 33% 置信度预测”计算”是下一个 token
89123量子0.52-0.65模型有 52% 置信度预测”量子”是下一个 token
154532比特0.66-0.42模型有 66% 置信度预测”比特”是下一个 token
209876经典0.76-0.28模型有 76% 置信度预测”经典”是下一个 token
381516430.39-0.95模型有 39% 置信度预测结束

与 DPO 对比:这里的数值与 DPO 不同,因为 SimPO 用的是 policy 自己的概率,不依赖 reference。

Step 3: 计算 Response Token 级别的 Log Prob

# Sum over sequence dimension(只计算 response 部分)
sum_log_probs = masked_log_probs.sum(dim=-1)  # [2]

# Count completion tokens(注意 padding 的部分)
# 由于 rejected 被 padding 到 40,需要计算真实的 response 长度
chosen_response_len = chosen_len - prompt_len  # 32
rejected_response_len = rejected_len - prompt_len  # 28

# Split chosen and rejected
chosen_sum_logp = sum_log_probs[0]  # scalar, 比如 -10.5
rejected_sum_logp = sum_log_probs[1]  # scalar, 比如 -13.2

关键位置的数值

Sum Log ProbResponse 长度含义
Chosen (y_w)-10.532chosen 所有 response token 的 log prob 总和
Rejected (y_l)-13.228rejected 所有 response token 的 log prob 总和

为什么 chosen 的 sum 更高(-10.5 > -13.2)

Step 4: Length-normalized Average Log Prob —— SimPO 的核心

SimPO 的关键创新:强制使用长度归一化的平均 log prob 作为隐式 reward

DPO vs SimPO 的归一化对比:DPO 既可以用 log prob 的总和(sum),也可以用平均(average),因为 DPO 的 reward 是 log ratio,长度因素在相减时部分抵消了。SimPO 直接对 policy 的 log prob 做平均,显式地消除长度影响,这是 SimPO 能去掉 reference model 的关键设计之一。

rSimPO(x,y)=βyt=1ylogπθ(ytx,y<t)r_{\text{SimPO}}(x, y) = \frac{\beta}{|y|} \sum_{t=1}^{|y|} \log \pi_\theta(y_t | x, y_{<t})
beta = 2.0  # SimPO 的 beta 通常比 DPO 大(DPO 是 0.1,SimPO 是 2.0)

# Length-normalized average log prob
chosen_avg_logp = chosen_sum_logp / chosen_response_len  # -10.5 / 32 = -0.328
rejected_avg_logp = rejected_sum_logp / rejected_response_len  # -13.2 / 28 = -0.471

# SimPO 隐式 reward(长度归一化后的 avg log prob,乘以 beta)
chosen_reward_simpo = beta * chosen_avg_logp  # 2.0 * (-0.328) = -0.656
rejected_reward_simpo = beta * rejected_avg_logp  # 2.0 * (-0.471) = -0.942

关键位置的数值(β=2.0):

Sum Log ProbResponse 长度Avg Log ProbSimPO Reward (β × avg)含义
Chosen (y_w)-10.532-0.328-0.656chosen 的 SimPO reward
Rejected (y_l)-13.228-0.471-0.942rejected 的 SimPO reward

Avg Log Prob 的意义

为什么 SimPO 的 beta 比 DPO 大

Step 5: SimPO Loss 计算

回顾 SimPO loss 公式:

LSimPO(πθ)=E(x,yw,yl)D[logσ(βywlogπθ(ywx)βyllogπθ(ylx)γ)]\mathcal{L}_{\text{SimPO}}(\pi_\theta) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\left( \frac{\beta}{|y_w|} \log \pi_\theta(y_w|x) - \frac{\beta}{|y_l|} \log \pi_\theta(y_l|x) - \gamma \right) \right]

其中 γ\gamma 是 target margin,表示期望 chosen 和 rejected 之间的最小差距。

SimPO 的 margin 定义:reward 差值减去 target margin,即 β(chosen_avgrejected_avg)γ\beta(\text{chosen\_avg} - \text{rejected\_avg}) - \gamma。当实际差距超过 target margin 时,margin > 0,loss 小;反之 loss 大,模型需要继续学习拉大差距。

gamma = 0.5  # target margin

# Margin(SimPO reward 的差值,减去 target margin)
margin = beta * (chosen_avg_logp - rejected_avg_logp) - gamma
# = 2.0 * (-0.328 - (-0.471)) - 0.5
# = 2.0 * 0.143 - 0.5
# = 0.286 - 0.5
# = -0.214

# SimPO loss(负的 log sigmoid)
loss = -F.logsigmoid(margin)  # scalar

# Backward
loss.backward()

数值计算

为什么 margin 是负数

训练效果(训练前后对比):

训练前 Chosen训练前 Rejected训练后 Chosen训练后 Rejected
Avg Log Prob-0.328-0.471-0.250-0.550
SimPO Reward-0.656-0.942-0.500-1.100
Margin (before γ)0.1430.300
Margin (after γ)-0.214+0.100
Loss0.8050.645

SimPO 与 DPO 的关键区别

特性DPOSimPO
Reward 公式βlogπθπref\beta \log \frac{\pi_\theta}{\pi_{\text{ref}}}$\frac{\beta}{
需要 reference✅ 需要❌ 不需要
Beta 典型值0.12.0
Reward 范围[-0.1, 0.1][-2, 0]
Target margin有(γ\gamma

为什么 SimPO 能不用 reference

为什么 SimPO 的 beta 更大

SimPO 代码实现

数据准备

# 假设有一对偏好数据
prompt = "请解释量子计算"
chosen = "量子计算利用量子比特进行计算,可以同时处理多个状态,计算能力远超经典计算机。"
rejected = "量子计算就是用量子做的计算,非常复杂,涉及很多物理原理。"

# Tokenize(拼接 prompt + response)
texts = [prompt + chosen, prompt + rejected]  # [2]
inputs = tokenizer(texts, return_tensors="pt", padding=True)
# input_ids: [2, 40](chosen 是 40,rejected 被 padding 到 40)
# attention_mask: [2, 40]

# 计算 prompt 和 response 的长度
prompt_tokens = tokenizer(prompt)["input_ids"]
prompt_len = len(prompt_tokens)  # 8
chosen_response_len = len(chosen_tokens) - prompt_len  # 32
rejected_response_len = len(rejected_tokens) - prompt_len  # 28examples/simpo_data.py

形状说明

前向传播与 log prob 提取

# Policy 前向传播(SimPO 只需要 policy,不需要 reference)
outputs = policy(**inputs)
logits = outputs.logits[:, :-1, :]  # [2, 39, V],预测下一个 token

# Log softmax to get log probs
log_probs = F.log_softmax(logits, dim=-1)  # [2, 39, V]

# 提取实际 token 对应的 log prob
labels = inputs.input_ids[:, 1:]  # [2, 39],目标 token
token_log_probs = log_probs.gather(
    dim=-1,
    index=labels.unsqueeze(-1)
).squeeze(-1)  # [2, 39]

# 创建 completion mask(只计算 response 部分)
completion_mask = torch.zeros_like(token_log_probs, dtype=torch.bool)
completion_mask[:, prompt_len-1:] = True  # [2, 39],从位置 7 开始

# Masked log probs(prompt 部分被 mask 为 0)
masked_log_probs = token_log_probs * completion_mask  # [2, 39]examples/simpo_forward.py

形状说明

Length-normalized average log prob

# Sum over sequence dimension(只计算 response 部分)
sum_log_probs = masked_log_probs.sum(dim=-1)  # [2]
# 比如:chosen_sum = -10.5, rejected_sum = -13.2

# Count completion tokens(真实长度,不含 padding)
response_lengths = torch.tensor([
    chosen_response_len,  # 32
    rejected_response_len,  # 28
])  # [2]

# Average log prob(length normalized)
avg_log_probs = sum_log_probs / response_lengths  # [2]
# 比如:chosen_avg = -0.328, rejected_avg = -0.471

# Split chosen and rejected
chosen_avg = avg_log_probs[0]    # scalar
rejected_avg = avg_log_probs[1]  # scalarexamples/simpo_avg_logp.py

形状说明

SimPO Loss

beta = 2.0  # SimPO 的 beta 通常比 DPO 大
gamma = 0.5  # target margin

# Margin with target(隐式 reward 的差值,减去 target margin)
margin = beta * (chosen_avg - rejected_avg) - gamma  # scalar
# = 2.0 * (-0.328 - (-0.471)) - 0.5
# = 2.0 * 0.143 - 0.5
# = 0.286 - 0.5
# = -0.214

# SimPO loss
loss = -F.logsigmoid(margin)  # scalar
# = -F.logsigmoid(-0.214)
# ≈ -log(0.447)
# ≈ 0.805

# Backward
loss.backward()examples/simpo_loss.py

形状说明

常见 bug

  1. 忘记 mask prompt:如果不 mask,avg_log_probs 会包含 prompt 部分,导致计算错误
  2. padding 长度计算错误:用 completion_mask.sum() 会包含 padding 部分,应该用真实长度
  3. beta 设置错误:SimPO 的 beta 通常比 DPO 大(2.0 vs 0.1),设置太小会导致训练不稳定
  4. gamma 设置错误:target margin 太大会导致训练困难,太小会让 margin 偏移中心

DPO 代码实现对比

DPO 的代码实现与 SimPO 类似,但需要额外计算 reference 的 log prob。

# Policy 前向传播
policy_outputs = policy(**inputs)
policy_logits = policy_outputs.logits[:, :-1, :]
policy_log_probs = F.log_softmax(policy_logits, dim=-1)
policy_logp = policy_log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
policy_logp_masked = policy_logp * completion_mask
policy_sum_logp = policy_logp_masked.sum(dim=-1)
policy_avg_logp = policy_sum_logp / response_lengths

# Reference 前向传播(frozen)
with torch.no_grad():
    ref_outputs = reference(**inputs)
    ref_logits = ref_outputs.logits[:, :-1, :]
    ref_log_probs = F.log_softmax(ref_logits, dim=-1)
    ref_logp = ref_log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    ref_logp_masked = ref_logp * completion_mask
    ref_sum_logp = ref_logp_masked.sum(dim=-1)
    ref_avg_logp = ref_sum_logp / response_lengths

# DPO Loss(需要 reference)
beta = 0.1  # DPO 的 beta 比较小
margin = beta * (policy_avg_logp[0] - ref_avg_logp[0]) - \
          beta * (policy_avg_logp[1] - ref_avg_logp[1])
# = 0.1 * (chosen_log_ratio - rejected_log_ratio)

loss = -F.logsigmoid(margin)
loss.backward()examples/dpo_loss.py

DPO vs SimPO 代码对比

特性DPOSimPO
前向传播次数2×(policy + reference)1×(只有 policy)
Beta0.12.0
Reward 计算βlogπθπref\beta \log \frac{\pi_\theta}{\pi_{\text{ref}}}βylogπθ\frac{\beta}{\mid y \mid} \log \pi_\theta
显存开销高(需要加载 reference)低(只需 policy)

小结:离线分支的演化

方法需要模型砍掉的组件用什么替代
PPOPolicy, Reference, RM, Critic在线 RL
DPOPolicy, ReferenceRM, Critic, PPOPolicy log ratio 作为隐式 reward
SimPOPolicyReference, RM, Critic, PPO长度归一化的 avg log prob

演化逻辑:

代价:


DPO 和 SimPO 的对比总结

RLHF 目标

maxθE[Rϕ(x,y)]βE[DKL(πθπref)]\max_\theta \, \mathbb{E}[R_\phi(x,y)] - \beta \mathbb{E}[D_{\text{KL}}(\pi_\theta \| \pi_{\text{ref}})]

PPO 完整 Objective

LCLIP+VF+S(θ)=E^t[LtCLIP(θ)c1LtVF(θ)+c2S[πθ](st)]L^{CLIP+VF+S}(\theta) = \hat{\mathbb{E}}_t\left[ L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S[\pi_\theta](s_t) \right]

其中:

DPO Loss

LDPO=E[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))]\mathcal{L}_{\text{DPO}} = -\mathbb{E}\left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right]

直觉:增大 chosen 的 log ratio,减小 rejected 的 log ratio

SimPO Loss

LSimPO=E[logσ(βywlogπθ(ywx)βyllogπθ(ylx)γ)]\mathcal{L}_{\text{SimPO}} = -\mathbb{E}\left[ \log \sigma\left( \frac{\beta}{|y_w|} \log \pi_\theta(y_w|x) - \frac{\beta}{|y_l|} \log \pi_\theta(y_l|x) - \gamma \right) \right]

直觉:不需要 reference,直接用 policy 自己的长度归一化 avg log prob

DPO vs SimPO 快速对比

特性DPOSimPO
隐式 rewardβlogπθπref\beta \log \frac{\pi_\theta}{\pi_{\text{ref}}}βylogπθ\frac{\beta}{\lvert y \rvert} \log \pi_\theta
需要 reference
Beta0.12.0
Target margin有(γ\gamma

下一篇回到在线分支:GRPO 用组采样替代 critic,DAPO 用四项修正把 GRPO 跑稳。



Previous Post
【九】在线RL:GRPO 与 DAPO 的推导与代码实现
Next Post
【七】基础RL:从策略梯度定理到 PPO 算法