Skip to content
传衡博客
返回

【七】基础RL:从策略梯度定理到 PPO 算法

参考资料
  1. Policy Gradient Methods for Reinforcement Learning with Function Approximation
  2. High-Dimensional Continuous Control Using Generalized Advantage Estimation
  3. Trust Region Policy Optimization
  4. Proximal Policy Optimization Algorithms
  5. Training language models to follow instructions with human feedback
  6. DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models
  7. DAPO: An Open-Source LLM Reinforcement Learning System at Scale
  8. TRL PPOTrainer Documentation
  9. TRL PPOTrainer Source Code
  10. TRL PPOConfig Source Code
  11. TRL Utility Functions Source Code
  12. DeepSpeed-Chat RLHF README
  13. DeepSpeed-Chat PPO Trainer Source Code
  14. OpenRLHF README
  15. OpenRLHF PPO Ray CLI
  16. OpenRLHF Experience Maker Source Code
  17. TRL DPOConfig Source Code

PPO(Proximal Policy Optimization)是目前使用最广泛的强化学习算法之一。从游戏 AI 到机器人控制,再到 ChatGPT 的后训练,PPO 都是默认选择。

PPO 解决了传统策略梯度方法的核心问题:训练不稳定。REINFORCE 等算法的梯度方差极大,策略可能突然崩溃。PPO 通过限制每次更新的幅度,把策略变化控制在一个小范围内,训练过程平稳可控。

相比 TRPO 需要二阶优化,PPO 只用普通 SGD 就能达到类似效果,实现成本更低。无论是连续控制还是离散决策,PPO 都能直接套用。

在 LLM 领域,PPO 是 RLHF 的标配。InstructGPT 用 PPO 做后训练,1.3B 模型在人工偏好胜率上达到 73.4%,显著超过 SFT 基线的 65.8%[5]。DPO、SimPO、GRPO、DAPO 这些后续算法,本质上都是对 PPO 的某种简化或改进。

这篇文章从策略梯度定理开始,一步步推到 PPO。目标读者是有深度学习基础但无 RL 背景的人。每一步会给出:数学推导、直观解释、在 LLM 场景下的具体含义和数值示例。

LLM 场景下的 RL 符号

PPO 原本是通用 RL 算法,放到 LLM 上只需要重新解释符号:

RL 概念LLM 对应物例子
state sts_tprompt + 已生成 token”请解释牛顿定律” + “牛顿定律是…“
action ata_t下一个 token”…描述”
policy πθ(as)\pi_\theta(a\|s)模型输出的 softmax 分布P("描述"上文)P(\text{"描述"} \| \text{上文})
trajectory τ\tau完整 completion从 prompt 到 EOS 的整个序列
reward R(τ)R(\tau)RM 分数 + KL 惩罚Rϕ(x,y)βKLR_\phi(x,y) - \beta \cdot \text{KL}

LLM-RL 优化的是逐 token 的策略分布。每个位置都在做”给定上文,下一个 token 应该是什么”的决策,reward 只在序列末尾(或特定检查点)才显现出来。

策略梯度定理:为什么能对策略直接求导

问题设定

策略是一个带参数 θ\theta 的函数 πθ(as)\pi_\theta(a\|s),输出在动作空间上的概率分布。目标是最大化期望累积回报:

J(θ)=Eτpθ(τ)[R(τ)]J(\theta) = \mathbb{E}_{\tau \sim p_\theta(\tau)}[R(\tau)]

其中轨迹概率 pθ(τ)p_\theta(\tau) 是状态转移概率和策略的乘积:

pθ(τ)=p(s0)t=0Tπθ(atst)p(st+1st,at)p_\theta(\tau) = p(s_0) \prod_{t=0}^{T} \pi_\theta(a_t \| s_t) \, p(s_{t+1} \| s_t, a_t)

关键观察:环境转移与策略参数无关

状态转移概率 p(st+1st,at)p(s_{t+1} \| s_t, a_t) 是环境的固有属性,与 θ\theta 无关。只有 πθ(atst)\pi_\theta(a_t \| s_t) 依赖于参数。

logpθ(τ)\log p_\theta(\tau) 求梯度:

θlogpθ(τ)=θ(logp(s0)+t=0Tlogπθ(atst)+t=0Tlogp(st+1st,at))\nabla_\theta \log p_\theta(\tau) = \nabla_\theta \left( \log p(s_0) + \sum_{t=0}^{T} \log \pi_\theta(a_t \| s_t) + \sum_{t=0}^{T} \log p(s_{t+1} \| s_t, a_t) \right)

其中 logp(s0)\log p(s_0)logp(st+1st,at)\log p(s_{t+1} \| s_t, a_t)θ\theta 无关,梯度为零。只剩下:

θlogpθ(τ)=t=0Tθlogπθ(atst)\nabla_\theta \log p_\theta(\tau) = \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \| s_t)

策略梯度定理的推导

J(θ)J(\theta) 求导,用 log-derivative trick(f=flogf\nabla f = f \cdot \nabla \log f):

θJ(θ)=θEτpθ[R(τ)]=θpθ(τ)R(τ)dτ=θpθ(τ)R(τ)dτ=pθ(τ)θlogpθ(τ)R(τ)dτ=Eτpθ[R(τ)θlogpθ(τ)]\begin{aligned} \nabla_\theta J(\theta) &= \nabla_\theta \mathbb{E}_{\tau \sim p_\theta}[R(\tau)] \\ &= \nabla_\theta \int p_\theta(\tau) R(\tau) \, d\tau \\ &= \int \nabla_\theta p_\theta(\tau) \, R(\tau) \, d\tau \\ &= \int p_\theta(\tau) \, \nabla_\theta \log p_\theta(\tau) \, R(\tau) \, d\tau \\ &= \mathbb{E}_{\tau \sim p_\theta}\left[ R(\tau) \, \nabla_\theta \log p_\theta(\tau) \right] \end{aligned}

代入上面的 θlogpθ(τ)\nabla_\theta \log p_\theta(\tau) 结果:

θJ(θ)=Eτpθ[R(τ)t=0Tθlogπθ(atst)]\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim p_\theta}\left[ R(\tau) \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \| s_t) \right]

这就是 Sutton 等人的策略梯度定理[1]

直观理解

公式 可以拆成两部分理解:

  1. θlogπθ(atst)\nabla_\theta \log \pi_\theta(a_t \| s_t):增大这个动作概率的方向
  2. R(τ)R(\tau):这个 scaling factor 决定往哪个方向走、走多远

如果 R(τ)>0R(\tau) > 0(轨迹总体是好的),梯度方向会让这条轨迹中出现过的动作概率增大。

如果 R(τ)<0R(\tau) < 0(轨迹总体是差的),梯度方向会让这些动作概率减小。

这就像”好动作多鼓励,坏动作多打压”。

为什么是 log 概率的梯度

直接对 πθ(as)\pi_\theta(a\|s) 求导,不同动作的概率尺度不同(比如常用词概率高,罕见词概率低),梯度大小不可比。取 log 后:

θlogπθ(as)=θπθ(as)πθ(as)\nabla_\theta \log \pi_\theta(a\|s) = \frac{\nabla_\theta \pi_\theta(a\|s)}{\pi_\theta(a\|s)}

这相当于相对变化率,消除了概率本身的尺度差异。

REINFORCE:策略梯度定理的直接实现

蒙特卡洛估计

公式 包含期望,实际计算时用采样估计:

θJ(θ)1Ni=1NR(τ(i))t=0Tθlogπθ(at(i)st(i))\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} R(\tau^{(i)}) \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t^{(i)} \| s_t^{(i)})

这就是 REINFORCE 算法:

  1. 用当前策略采 NN 条轨迹
  2. 计算每条轨迹的总回报 R(τ(i))R(\tau^{(i)})
  3. 按公式更新参数

REINFORCE 的问题:方差太大

R(τ)R(\tau) 是整条轨迹的累积回报,方差极大。数值示例:

轨迹累积回报 R(τ)R(\tau)动作”apple”出现梯度方向问题
A+1001 次正(鼓励)高回报可能来自运气
B-501 次负(打压)低回报可能来自噪音
C+800 次没用到这个信息

同一个动作”apple”,在 A 中被鼓励,在 B 中被打压。但 A 和 B 的回报差异(150 分)可能完全来自环境随机性,而不是动作本身的好坏。模型会在”鼓励 apple”和”打压 apple”之间剧烈抖动。

方差量化:假设回报服从 RN(50,1002)R \sim \mathcal{N}(50, 100^2),则梯度估计的方差与 1002=10000100^2 = 10000 成正比。Sutton 指出 REINFORCE 是无偏估计,但方差太大[1]。解决方案是引入 baseline。

Baseline 与 Advantage:降低方差

Baseline 的核心思想:不看绝对回报,看相对于平均水平的超额收益

Baseline 的数学原理

给每个动作加上一个与动作无关的 baseline b(st)b(s_t)

θJ(θ)=E[t=0Tθlogπθ(atst)(Q(st,at)b(st))]\nabla_\theta J(\theta) = \mathbb{E}\left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \| s_t) \cdot (Q(s_t, a_t) - b(s_t)) \right]

为什么 baseline 不改变期望?

Eatπθ[θlogπθ(atst)b(st)]=b(st)Eat[θlogπθ(atst)]=0\mathbb{E}_{a_t \sim \pi_\theta}\left[ \nabla_\theta \log \pi_\theta(a_t \| s_t) \cdot b(s_t) \right] = b(s_t) \cdot \mathbb{E}_{a_t}\left[ \nabla_\theta \log \pi_\theta(a_t \| s_t) \right] = 0

因为:

Eat[θlogπθ(atst)]=aπθ(as)θπθ(as)πθ(as)=aθπθ(as)=θaπθ(as)=θ1=0\mathbb{E}_{a_t}\left[ \nabla_\theta \log \pi_\theta(a_t \| s_t) \right] = \sum_{a} \pi_\theta(a\|s) \cdot \frac{\nabla_\theta \pi_\theta(a\|s)}{\pi_\theta(a\|s)} = \sum_{a} \nabla_\theta \pi_\theta(a\|s) = \nabla_\theta \sum_{a} \pi_\theta(a\|s) = \nabla_\theta 1 = 0

概率分布之和为 1,梯度为 0。

常用 Baseline:状态价值函数

从直觉上,用状态的平均回报作为 baseline 很自然。给定 Q(st,at)Q(s_t, a_t) 是在状态 sts_t 采取动作 ata_t 后的期望回报,我们定义状态价值函数 V(st)V(s_t)——在这个状态下”平均来说”能得到多少回报:

V(st)=Eaπθ[Q(st,a)]V(s_t) = \mathbb{E}_{a \sim \pi_\theta}[Q(s_t, a)]

严格地说,最小方差 baseline 带有梯度权重;工程中通常用 V(s)V(s) 作为近似,简单且效果很好。

定义 Advantage(优势函数):

A(st,at)=Q(st,at)V(st)A(s_t, a_t) = Q(s_t, a_t) - V(s_t)

直观理解:不是看绝对回报,而是看”相对于平均水平的超额收益”。

数值示例

假设有三个轨迹在相同初始状态下采取不同动作:

轨迹总回报状态平均 V(s)V(s)Advantage含义
A+120100+20比平均水平好 20%,应该鼓励
B+80100-20比平均水平差 20%,应该打压
C+1001000平均水平,不更新

用 Advantage 后,只有轨迹 A 被鼓励、B 被打压。REINFORCE 用原始回报时,A 和 C 都被鼓励(因为都 >0),区分度不够。

为什么能降低方差

原始回报范围可能是 [-100, +1000],方差极大。减去 baseline 后,advantage 范围缩小到 [-50, +50],梯度的 variance 大幅降低,训练更稳定。

GAE:Generalized Advantage Estimation

实际中 QQVV 都需要估计。Schulman 等人提出 GAE[2],用 TD error(时序差分误差)的加权平均来估计 advantage。

定义 TD error(单步 Bellman residual):

δtV=rt+γV(st+1)V(st)\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)

如果 δtV>0\delta_t^V > 0,说明”实际得到的(即时奖励 + 下一状态价值)“比”当前估计的”要高,动作比预期好。

GAE 把多个 time step 的 TD error 加权求和:

A^tGAE(γ,λ)=l=0(γλ)lδt+lV\hat{A}^{\text{GAE}(\gamma,\lambda)}_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}^V

数值示例(设 γ=0.99\gamma=0.99, λ=0.95\lambda=0.95):

位置 ttrtr_tV(st)V(s_t)V(st+1)V(s_{t+1})δtV\delta_t^VGAE 贡献 (γλ)tδtV(\gamma\lambda)^t \delta_t^V
0+0.12.52.4+0.1 + 0.99×2.4 - 2.5 = +0.036+0.036
1+0.22.42.3+0.2 + 0.99×2.3 - 2.4 = +0.0770.94×0.077 = +0.072
2+0.02.32.2+0.0 + 0.99×2.2 - 2.3 = -0.1220.9420.94^2×(-0.122) = -0.108
3+2.02.20.0 (终止)+2.0 - 2.2 = -0.20.9430.94^3×(-0.2) = -0.166

位置 0 的 advantage:A^0=0.036+0.0720.1080.166=-0.166\hat{A}_0 = 0.036 + 0.072 - 0.108 - 0.166 = \textbf{-0.166}

λ=1\lambda=1 时,GAE 退化为 Monte Carlo:A^0=(0.1+0.2+0.0+2.0)2.5=0.2\hat{A}_0 = (0.1+0.2+0.0+2.0) - 2.5 = -0.2,与上表最后一行一致。

极端情况

PPO 论文默认用 GAE 估计 advantage[4]

重要性采样与 PPO

问题:REINFORCE 是 on-policy 的

策略梯度定理要求期望是在当前策略 pθp_\theta 下计算的。这意味着:

  1. 采一批数据
  2. 用这批数据算梯度、更新参数
  3. 参数变了,数据就”过期”了,必须扔掉重采

数据效率极低。

重要性采样:用旧策略的数据

设旧策略为 πθold\pi_{\theta_{\text{old}}},新策略为 πθ\pi_\theta。用旧策略采的数据,通过重要性权重修正,可以估计新策略的期望:

Eτpθ[f(τ)]=Eτpθold[pθ(τ)pθold(τ)f(τ)]\mathbb{E}_{\tau \sim p_\theta}[f(\tau)] = \mathbb{E}_{\tau \sim p_{\theta_{\text{old}}}}\left[ \frac{p_\theta(\tau)}{p_{\theta_{\text{old}}}(\tau)} f(\tau) \right]

对于单步动作,重要性比率为:

rt(θ)=πθ(atst)πθold(atst)r_t(\theta) = \frac{\pi_\theta(a_t \| s_t)}{\pi_{\theta_{\text{old}}}(a_t \| s_t)}

朴素的 surrogate objective

把重要性采样用到策略梯度上:

L(θ)=Et[rt(θ)A^t]L(\theta) = \mathbb{E}_t\left[ r_t(\theta) \cdot \hat{A}_t \right]

这看起来可以:用旧数据,乘以比率修正,就能更新新策略。

但问题是:如果新策略 πθ\pi_\theta 和旧策略 πθold\pi_{\theta_{\text{old}}} 差异太大,rt(θ)r_t(\theta) 会远离 1,估计器会变得不稳定。反复做多轮 SGD 后,ratio 会爆炸或归零。

TRPO 的思路:KL 约束

Schulman 在 TRPO 中的解决方案[3]:加约束,限制每次策略更新的幅度:

maxθEt[rt(θ)A^t]s.t.Et[DKL(πθoldπθ)]δ\max_\theta \, \mathbb{E}_t\left[ r_t(\theta) \hat{A}_t \right] \quad \text{s.t.} \quad \mathbb{E}_t\left[ D_{\text{KL}}(\pi_{\theta_{\text{old}}} \| \pi_\theta) \right] \leq \delta

这是一个约束优化问题,需要用二阶方法(共轭梯度)求解,实现复杂。

PPO 的 clip 近似

PPO 用一阶方法近似 TRPO 的约束。核心思想:不要让 ratio rt(θ)r_t(\theta) 偏离 1 太远

定义 clipped surrogate objective:

LCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[ \min\left( r_t(\theta)\hat{A}_t, \, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t \right) \right]

其中 clip(r,1ϵ,1+ϵ)\text{clip}(r, 1-\epsilon, 1+\epsilon) 把 ratio 截断在 [1ϵ,1+ϵ][1-\epsilon, 1+\epsilon] 区间内(通常 ϵ=0.2\epsilon = 0.2)。

Clip 的四种情况分析

公式 中的 min\min 操作会产生四种情况:

情况条件unclippedclipped最终取值效果
1A^t>0\hat{A}_t > 0rt>1+ϵr_t > 1+\epsilon很大(1+ϵ)A^t(1+\epsilon)\hat{A}_tclipped好动作不会被无限放大
2A^t>0\hat{A}_t > 0rt1+ϵr_t \leq 1+\epsilon正常\geq unclippedunclipped继续增大概率
3A^t<0\hat{A}_t < 0rt<1ϵr_t < 1-\epsilon很小(1ϵ)A^t(1-\epsilon)\hat{A}_tclipped坏动作不会被无限打压
4A^t<0\hat{A}_t < 0rt1ϵr_t \geq 1-\epsilon正常\leq unclippedunclipped继续减小概率

情况 1 的详细解释

情况 3 是对称的:对于坏动作,防止过度打压导致概率归零。

PPO 在 LLM 上的四模型架构

InstructGPT 把 PPO 用于语言模型后训练,流程是:SFT → 训练 reward model → PPO 微调[5]

RL 阶段需要 4 个模型:

模型符号作用是否训练
Policyπθ\pi_\theta生成回答的模型✅ 训练
Referenceπref\pi_{\text{ref}}SFT 模型的快照,用于 KL 约束❌ 冻结
Reward ModelRϕR_\phi给 completion 打分❌ 冻结
Critic / ValueVψV_\psi估计每个 token 位置的 value✅ 训练

RLHF 目标函数

InstructGPT 的 RL 目标[5]

maxθE(x,y)πθ[Rϕ(x,y)βlogπθ(yx)πref(yx)]\max_\theta \, \mathbb{E}_{(x,y) \sim \pi_\theta} \left[ R_\phi(x,y) - \beta \log \frac{\pi_\theta(y \| x)}{\pi_{\text{ref}}(y \| x)} \right]

两项分别表示:

  1. Rϕ(x,y)R_\phi(x,y):reward model 对生成质量的整体评分
  2. βlogπθπref-\beta \log \frac{\pi_\theta}{\pi_{\text{ref}}}:KL 惩罚,防止 policy 偏离 reference 太远

KL 惩罚的直观理解

logπθπref\log \frac{\pi_\theta}{\pi_{\text{ref}}} 衡量新策略偏离 reference 的程度。

数值示例(设 β=0.1\beta=0.1):

回答policy 概率ref 概率logπθπref\log \frac{\pi_\theta}{\pi_{\text{ref}}}KL 惩罚
回答 A0.50.3log(1.67)=+0.51\log(1.67) = +0.51-0.051
回答 B0.20.4log(0.5)=0.69\log(0.5) = -0.69+0.069

回答 A 的概率从 30% 提升到 50%,被惩罚 -0.051;回答 B 的概率从 40% 降到 20%,被奖励 +0.069。整体效果是防止 policy 剧烈偏离 reference。

Reward Model 训练:从人类偏好到打分模型

Rϕ(x,y)R_\phi(x,y) 是一个神经网络,输出回答的评分。它不是”环境直接给出的标量”,而是从人类偏好数据训练出来的。

数据收集:

收集大量 prompt 对应的两个回答,让人类标注”更喜欢哪个”:

Prompt: "写一个冒泡排序的 Python 实现"

Answer A: "用两个循环嵌套..."
Answer B: "用冒泡排序,时间复杂度 O(n²)..."
人类偏好: A (更简洁)

模型训练:

用偏好对训练 RM,让它学会预测”人类更偏好哪个”。这本质上是二分类排序任务

LBT=E[logσ(ywyl)]\mathcal{L}_{\text{BT}} = -\mathbb{E}\left[ \log \sigma(y_w - y_l) \right]

其中 ywy_wyly_l 分别是人类偏好的和偏差的回答,σ\sigma 是 sigmoid 函数。

为什么不用规则?

人类偏好很复杂:准确、有用性、安全性、风格等。规则难以穷举所有情况,训练模型可以学会这种”直觉”。

PPO 训练过程拆解

假设我们有一条 prompt,模型生成了 response。用具体数值演示整个训练循环。

输入示例

整体流程

阶段步骤做什么输出
Rollout1-5生成响应,计算 log probs、KL、rewards、returnsreturns, old_logp, values
Critic 训练6用 returns 作为 target,训练 criticvalue_loss
Policy 训练7-8用 advantage,训练 policy 优化决策policy_loss

Rollout 与数据准备

Step 1: Rollout —— 生成响应

# Tokenize
prompt_tokens = tokenizer("请解释牛顿第一定律")["input_ids"]  # shape: [8]
prompt_len = len(prompt_tokens)  # 8

# Generate
full_ids = policy.generate(prompt_tokens, max_new_tokens=25)  # shape: [33]
response_ids = full_ids[prompt_len:]  # shape: [25]

张量形状

变量形状含义
prompt_tokens[8]prompt 有 8 个 token
full_ids[33]完整序列 = prompt(8) + response(25)
response_ids[25]只取生成的 25 个 token

关键位置的 token

位置token中文说明
08712牛顿response 开头
35421定律中间位置
24151643序列结束符

Step 2: 提取 Log Probs —— 从 logits 到标量

TRL 将这一步封装成 selective_log_softmax 函数,具体写法如下。[11]

# Policy 输出 logits,形状是 [batch=1, seq_len=25, vocab=151936]
logits = policy(input_ids).logits[:, prompt_len-1:-1, :]  # [1, 25, 151936]

# 取 log_softmax
log_probs = F.log_softmax(logits, dim=-1)  # [1, 25, 151936]

# 提取实际 token 对应的 log prob
logp = log_probs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)  # [1, 25]examples/log_probs.py

Shape 解释

关键位置的数值

位置tokenlogp对应概率含义
0牛顿-1.230.29模型有 29% 置信度预测”牛顿”
3定律-0.450.64模型有 64% 置信度预测”定律”
24-2.100.12模型有 12% 置信度预测结束

logp 的意义:对数概率。logp 越接近 0,模型越”确信”这个 token 是正确的。指数后得到概率:exp(-0.45) ≈ 0.64

logp vs 概率

特性概率logp
范围[0, 1][-∞, 0]
连乘0.1×0.1×0.1=0.001(易下溢)(-2.3)+(-2.3)+(-2.3)=-6.9(稳定)
梯度小概率区域梯度消失对数空间梯度均匀

LLM 的词表通常有 10 万+ token,单次预测概率常在 0.001-0.1 之间。25 个 token 的序列概率连乘会变成 105010^{-50} 量级,float32 直接下溢成 0。取 log 后变成求和,数值稳定。

为什么公式用 logp:策略梯度定理 logπ\nabla \log \pi、KL 散度定义,自然形式就是基于 logp 的。

Step 3: 计算策略偏差 —— policy 离 reference 有多远

# Reference model(frozen,不训练)
with torch.no_grad():
    ref_logits = reference(input_ids).logits[:, prompt_len-1:-1, :]  # [1, 25, 151936]
    ref_log_probs = F.log_softmax(ref_logits, dim=-1)
    ref_logp = ref_log_probs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)  # [1, 25]

# Log-ratio(KL 散度的 Monte Carlo 估计项)
# 真正的 KL 散度是所有可能 token 的期望,恒非负
# 这里是采样到的单个 token 的 log-ratio,可以为负
kl = logp - ref_logp  # [1, 25]  # 即 log(π/π_ref)examples/ref_logp_and_kl.py

关键位置的数值(设 β=0.1):

位置logpref_logplog_ratio (kl)含义
0-1.23-1.10-0.13policy 比 ref 更”保守”
3-0.45-0.38-0.07policy 比 ref 稍微更保守
24-2.10-2.25+0.15policy 比 ref 更”激进”

Log-ratio 的意义

这是 KL 散度的 Monte Carlo 估计:DKL(ππref)E[logππref]D_{\text{KL}}(\pi \| \pi_{\text{ref}}) \approx \mathbb{E}[\log \frac{\pi}{\pi_{\text{ref}}}]。对单个采样 token,这个值可以取负,但对整个分布的期望恒非负。

公式 中的第二项 βlogπθπref=βlog_ratio-\beta \log \frac{\pi_\theta}{\pi_{\text{ref}}} = -\beta \cdot \text{log\_ratio} 就是在做这件事:当 log_ratio 为正(policy 概率比 ref 高),惩罚为负(打压过高概率);当 log_ratio 为负(policy 概率比 ref 低),惩罚为正(鼓励提高概率)。

Step 4: 合成 Token Reward —— log-ratio 惩罚 + 末尾 RM 分数

# Token-level reward:每个位置都是负的 log-ratio 惩罚(KL 约束项)
token_reward = -beta * kl  # [1, 25]

# Sequence-level reward:只在末尾加上 RM 分数
rm_score = reward_model(full_ids).scores  # [1], 比如 +2.0
token_reward[0, -1] += rm_score  # 在 EOS 位置加上 RM 分数examples/token_reward.py

关键位置的数值(β=0.1,RM score=+2.0):

位置log_ratio-β·log_ratioRM 加成token_reward
0-0.13+0.0130+0.013
3-0.07+0.0070+0.007
24+0.15-0.015+2.0+1.985

reward 的结构

Step 5: 计算 Returns —— 从后往前累加

Return 是”从当前位置开始到序列结束的累积回报”。用反向累积求和:

# 从后往前累加(假设 discount γ=1,LLM 通常不做 discount)
returns = torch.cumsum(token_reward.flip(dims=[1]), dim=1).flip(dims=[1])  # [1, 25]

关键位置的数值

位置token_rewardreturn含义
0+0.013+1.995从位置 0 开始,未来能获得的总回报
3+0.007+1.982从位置 3 开始,未来能获得的总回报
24+1.985+1.985最后一个位置的回报就是它自己

计算过程(从后往前):

return 的意义:告诉 critic “在位置 tt 做决策,未来大概能拿多少分”。


Critic 训练

Step 6: 训练 Critic —— 用 MSE loss 预测 returns

这一步训练 critic,让它学会预测从每个位置开始的累积回报。

# Critic 输出每个位置的 value 预测
values = critic(input_ids).scores[:, prompt_len:]  # [1, 25]

# Critic 训练目标:让预测的 values 接近实际的 returns
value_loss = F.mse_loss(values, returns)

关键位置的数值(critic 预测的 values):

位置returncritic 预测误差
0+1.995+1.800+0.195
3+1.982+1.950+0.032
24+1.985+2.100-0.115

Critic 的作用:学会预测 V(st)V(s_t),即在状态 sts_t 下”未来能拿多少分的期望值”。训练好后,critic 的预测值就是 baseline,用来计算 advantage。


Policy 训练

Step 7: 计算 Advantage —— 连接两个模型的桥梁

Critic 训练好后,用它来计算 advantage。

# Advantage 计算(两种等效写法)
# 方法1:returns - values(蒙特卡洛,等价于 GAE with λ=1)
adv = returns - values

# 方法2:GAE(实际更常用,用 weighted sum of TD errors)
# adv = gae_estimation(td_errors, gamma=0.99, lam=0.95)

# Whiten:标准化到均值 0、标准差 1
adv = (adv - adv.mean()) / (adv.std() + 1e-8)

两种方法的关系:

方法公式特点典型场景
Returns - ValuesA^t=RtV(st)\hat{A}_t = R_t - V(s_t)无偏、方差大简单实现、稀疏 reward
GAEA^t=l(γλ)lδt+l\hat{A}_t = \sum_l (\gamma\lambda)^l \delta_{t+l}有偏、方差小PPO 论文推荐、连续控制

PPO 原始论文明确推荐使用 GAE(λ=0.95\lambda=0.95[4]。TRL 官方 PPOTrainer 的默认实现也是 GAE:PPOConfig.lam 默认值为 0.95,在 compute_advantages 函数里使用 lastgaelam 递推 advantage,再用 returns = advantages + values 得到 return[8][9][10]

这段实现基本就是标准 GAE:

lastgaelam = 0
advantages_reversed = []
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
    lastgaelam = delta + args.gamma * args.lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + valuestrl/experimental/ppo/ppo_trainer.py

在 LLM 场景里,由于 reward 往往集中在序列末尾,很多教程和自定义实现会直接用 returns - values 代替标准的 GAE。下文的数值示例也使用这个更直观的写法。

关键位置的数值(用 returns - values):

位置returncritic 预测raw_advwhiten 后 adv
0+1.995+1.800+0.195+1.32
3+1.982+1.950+0.032+0.21
24+1.985+2.100-0.115-0.77

advantage 的意义

为什么要 whiten:不同 batch 的 return 尺度可能差异很大。标准化后,advantage 分布稳定,PPO clip 的 ±ε 截断才有意义。

Step 8: 训练 Policy —— PPO Clip Loss

用 advantage 作为权重,鼓励好动作、打压坏动作。核心:importance sampling ratio rtr_t 乘 advantage,但用 clip 限制 rtr_t 的变化幅度。

# 用更新后的 policy 重新算 log prob
new_logp = F.log_softmax(policy(input_ids).logits[:, prompt_len-1:-1, :], dim=-1)
new_logp = new_logp.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)  # [1, 25]

# Importance sampling ratio
ratio = (new_logp - old_logp).exp()  # [1, 25]

# Clip
clipped_ratio = ratio.clamp(1 - eps, 1 + eps)  # ε=0.2

# Policy loss(注意取负号,因为我们要最大化)
policy_loss = -torch.min(ratio * adv, clipped_ratio * adv).mean()examples/ppo_policy_loss.py

关键位置的数值(ε=0.2):

位置old_logpnew_logpratioclipped_ratioadvunclipped 贡献clipped 贡献min 取值
0-1.23-1.151.091.09+1.32+1.44+1.44+1.44
3-0.45-0.221.251.20+0.21+0.26+0.25+0.25
24-2.10-1.801.351.20-0.77-1.04-0.92-1.04

位置 0(好动作,适度优化)

位置 3(好动作,过度优化被限制)

四种情况汇总

场景ratioclippedadv效果
好动作,超上限1.251.20+限制过度鼓励
好动作,正常1.151.15+正常鼓励
坏动作,超下限0.700.80-限制过度打压
坏动作,正常0.850.85-正常打压

核心目的:防止策略在一次更新中变化太大,保持训练稳定性。

PPO 的 Reward 设计

InstructGPT 用单个 reward model。DeepSeekMath 和 DAPO 在数学任务中用 rule-based verifier[6][7]

实际场景 reward 可以拆成多个维度:

进入 PPO 前必须合成标量:

Rtotal=w1Rcorrect+w2Rformat+w3Rtool+w4RlengthR_{\text{total}} = w_1 \cdot R_{\text{correct}} + w_2 \cdot R_{\text{format}} + w_3 \cdot R_{\text{tool}} + w_4 \cdot R_{\text{length}}

不同维度量纲不同,需要先做归一化或缩放。

代码实现

用 Qwen3-8B 为例,展示 PPO-RLHF 的关键代码。

下面 3 段代码按 TRL PPOTrainer 的 rollout / reward / loss 流程进行展示,省略 padding mask、truncate response 和分布式训练的工程细节[9][11]

Rollout 阶段

# prompts: list of strings
batch = tokenizer(prompts, return_tensors="pt", padding=True)
generated = policy.generate(**batch, max_new_tokens=256)

prompt_len = batch["input_ids"].shape[1]
response_ids = generated[:, prompt_len:]                 # [B, T_resp]

# Policy forward for log probs
outputs = policy(generated)
logits = outputs.logits[:, prompt_len - 1 : -1]         # [B, T_resp, V]

# Reference forward (no grad)
with torch.no_grad():
    ref_outputs = reference(generated)
    ref_logits = ref_outputs.logits[:, prompt_len - 1 : -1]  # [B, T_resp, V]

# Critic value estimation
values = critic(generated).scores[:, prompt_len:]       # [B, T_resp]

# Reward model score (sequence level)
rm_score = reward_model(generated).scores.squeeze(-1)   # [B]examples/ppo_rollout.py

Reward 计算

# Extract token log probs
def token_logps(logits, ids):
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(-1, ids.unsqueeze(-1)).squeeze(-1)

logp = token_logps(logits, response_ids)                # [B, T_resp]
ref_logp = token_logps(ref_logits, response_ids)        # [B, T_resp]
kl = logp - ref_logp                                    # [B, T_resp]

# Token-level reward: negative KL + sequence reward at the end
token_reward = -beta * kl                               # [B, T_resp]
token_reward[:, -1] += rm_score                         # Add sequence reward at EOS

# Compute returns (reverse cumulative sum)
returns = reverse_cumsum(token_reward)                  # [B, T_resp]

# Advantage
adv = returns - values                                  # [B, T_resp]
adv = whiten(adv)  # Normalize to mean 0, std 1examples/ppo_rewards.py

PPO Loss

# Recompute log probs with new policy
new_logits = policy(generated).logits[:, prompt_len - 1 : -1]
new_logp = token_logps(new_logits, response_ids)        # [B, T_resp]

old_logp = logp.detach()                                # [B, T_resp]

# Importance sampling ratio
ratio = (new_logp - old_logp).exp()                     # [B, T_resp]

# Clip
clipped_ratio = ratio.clamp(1 - eps, 1 + eps)           # [B, T_resp]

# Policy loss (negative because we minimize)
policy_loss = -torch.min(
    ratio * adv,
    clipped_ratio * adv
).mean()

# Value loss
value_loss = F.mse_loss(values, returns)

# Total loss
loss = policy_loss + c1 * value_lossexamples/ppo_loss.py

教科书 PPO vs 工程版 PPO

教科书 PPO[3]严格按照论文实现,特点是简单直接但样本效率低。

严格 on-policy:采一批数据,更新几轮,然后扔掉重新采。PPO 原始论文是 on-policy 算法,期望是在当前策略分布下计算梯度。如果策略参数更新太多,rollout 的数据就不再代表新策略的分布,继续用旧数据训练会产生偏差。因此教科书做法是每轮重新 rollout,虽然样本利用率低但理论保证正确。

工程版 PPO(如 InstructGPT)在实际部署时对”纯 PPO”做了多项工程优化:

Pretraining mix:在 RLHF 数据中混入 SFT 数据,防止模型只针对 RM 评分优化而丧失通用能力。InstructGPT 发现,如果不混入 SFT 数据,模型会在 RM 任务上表现很好,但其他任务的性能会下降[5]

Rollout buffer:近端数据回放,重复使用之前采集但未充分训练的数据。纯 PPO 每轮都要重新 rollout,样本利用率低;工程版会把最近 N 轮的 rollout 存下来,在多轮 PPO 更新中复用。

PPO 的工程优化 Reference Log Prob 缓存

PPO 的 response 是在线生成的,每轮 rollout 都不一样,所以 reference 的 log prob 可以通过按 batch 缓存的方式,避免重复计算。

Batch 级缓存的实现

TRL 在 rollout 阶段一次性算出 ref_logprobs,和 logprobsvaluesrewards 一起存进 buffer。在进行 PPO 训练的时候只读取缓存值做更新,reference 模型不再对缓存的数据做前向[9]

DeepSpeed-Chat 的做法为:rollout 返回字典里直接带上 ref_logprobs 字段,训练时从 inputs["ref_logprobs"] 读取缓存的 ref_logprobs [12]

return {
    "logprobs": gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
    "ref_logprobs": gather_log_probs(logits_ref[:, :-1, :], seq[:, 1:]),
    "value": values,
    "rewards": reward_score,
}applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py

缓存的局限

每轮新 rollout 开始时,reference 仍然要前向一次,给新生成的 response 计算 KL Baseline。真正省掉的是重复前向的消耗。

工程上的优化方向是降低这单次前向的开销。DeepSpeed-Chat 支持 reference CPU offload;OpenRLHF 提供 --colocate_actor_ref--ref_reward_offload,在 experience_maker 里单独调用 initial_model_group 时生成 base_action_log_probs_ref[14][16]

与 DPO 的区别

DPO 的样本是静态的,可以预计算全量 reference log prob。TRL 的 DPOConfig 提供 precompute_ref_log_probs=True,训练时完全不需要保留 ref 模型[17]。PPO 的 rollout 每轮变化,只能用 batch 级缓存配合 offload 来降低开销。

难例重采样:按 reward 分桶,对低 reward 样本进行过采样。RM 给低分的样本说明模型当前有短板,多采样这些”困难样本”可以加速改进。通常按 reward 分位数(如 bottom 20%)多采样 2-3 倍。

无效样本过滤:直接丢弃明显无效的 rollout。包括:生成过长(如超过 2048 token)、大量重复内容、格式错误(如 JSON 解析失败)的样本。这些样本计算开销大且没有训练价值,提前过滤可以节省计算资源。

这些变体让实现偏离”纯 PPO”,但提升了样本效率和训练稳定性。

小结:演化树预览

PPO 的四模型架构是起点,后续算法都在砍组件:

算法PolicyReferenceReward ModelCritic核心改动
PPO基线
DPO消掉在线 RL,变成离线偏好学习
SimPO再消掉 reference
GRPO用组采样替代 critic
DAPOGRPO + 四项工程修正

下一篇从 RLHF 目标出发,推导 DPO 如何把四模型压缩成二分类 loss。



Previous Post
【八】离线RL:DPO 与 SimPO 的推导与代码实现
Next Post
【六】长上下文 SFT 与双卡 BranchParallel + SimPO 代码实现