训练大模型时最崩溃的时刻,莫过于看到 RuntimeError: CUDA out of memory。Activation Checkpointing(激活检查点,也叫梯度检查点)能把显存占用降低 50%~70%,代价是训练时间增加约 20%。这篇文章会用一个简单的类比帮你建立直觉,然后手把手教你用起来。
参考资料
显存都去哪了?
训练神经网络时,GPU 显存主要被四部分占用:
| 组成部分 | 1B 模型示例 | 说明 |
|---|---|---|
| 模型参数 | 4 GB | 存储权重 |
| 梯度 | 4 GB | 反向传播计算得到 |
| 优化器状态 | 8 GB | Adam 需要维护动量等状态 |
| 激活值 | 32 GB | 前向计算的中间结果 |
| 合计 | 48 GB | batch_size=4, seq_len=2048 |
激活值是最大的显存消耗者。对于 Transformer 模型,激活值显存与层数、序列长度、batch size 成正比。当模型变大或序列变长时,激活值会迅速吃掉所有显存。
草稿纸类比
想象你在做一道很长的数学题,需要多步计算。
正常训练的做法:每一步计算都把中间结果写在草稿纸上。这样回头检查时,所有中间结果都随手可得。问题是,草稿纸(显存)很快就用完了。
Activation Checkpointing 的做法:只在关键步骤(检查点)保存中间结果,其他步骤需要时重新计算。代价是多花点时间重算,但草稿纸用量大幅减少。
具体到神经网络:
- 前向传播:只保存检查点层的输出,丢弃其他层的激活值
- 反向传播:需要某个层的激活值时,从最近的检查点重新前向计算到该层
流程对比:ASCII 图示
假设有一个 4 层的网络(Layer 1 → Layer 2 → Layer 3 → Layer 4):
正常训练
前向传播:
Input → [L1] → [L2] → [L3] → [L4] → Loss
↓ ↓ ↓ ↓
保存 保存 保存 保存 ← 所有激活值都存
反向传播:
Loss ← [L4] ← [L3] ← [L2] ← [L1] ← Grad
↑ ↑ ↑ ↑
读取 读取 读取 读取 ← 直接使用保存的激活值
显存占用:O(L) 随层数线性增长
Activation Checkpointing(以 Layer 2 为检查点)
前向传播:
Input → [L1] → [L2] → [L3] → [L4] → Loss
↓ ✓ ↓ ↓
丢弃 保存 丢弃 丢弃 ← 只保存检查点
反向传播:
Loss ← [L4] ← [L3] ← [L2] ← [L1] ← Grad
↑ ↑
重算 ← 从检查点重新前向到 L3
↑ ↑
重算 ← 从 Input 重新前向到 L1
显存占用:O(1) 常数级别(只存检查点)
PyTorch 实战
PyTorch 提供了 torch.utils.checkpoint 来实现这一功能。
基础用法:一个简单的 MLP
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class SimpleMLP(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(1000, 1000)
self.layer2 = nn.Linear(1000, 1000)
self.layer3 = nn.Linear(1000, 10)
self.relu = nn.ReLU()
def forward(self, x):
# 普通前向:直接调用
# x = self.relu(self.layer1(x))
# x = self.relu(self.layer2(x))
# Checkpoint 版本:用 checkpoint 包装
x = checkpoint(self._forward_block, x, use_reentrant=False)
return x
def _forward_block(self, x):
x = self.relu(self.layer1(x))
x = self.relu(self.layer2(x))
x = self.layer3(x)
return x
# 显存对比测试
def test_memory():
model = SimpleMLP().cuda()
x = torch.randn(32, 1000, requires_grad=True).cuda()
# 普通训练
torch.cuda.reset_peak_memory_stats()
out = model._forward_block(x)
out.sum().backward()
mem_normal = torch.cuda.max_memory_allocated() / 1024**2
# Checkpoint 训练
model.zero_grad(set_to_none=True)
torch.cuda.reset_peak_memory_stats()
out = model(x)
out.sum().backward()
mem_checkpoint = torch.cuda.max_memory_allocated() / 1024**2
print(f"普通模式显存: {mem_normal:.1f} MB")
print(f"Checkpoint 模式显存: {mem_checkpoint:.1f} MB")
print(f"节省: {(1 - mem_checkpoint/mem_normal)*100:.1f}%")
test_memory()examples/basic_checkpoint.py
运行结果示例:
普通模式显存: 256.0 MB
Checkpoint 模式显存: 128.0 MB
节省: 50.0%
在 Transformer 中的应用
实际使用中,通常对整个 Transformer Block 应用 checkpoint:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
# 自注意力
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
# 前馈网络
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x
class CheckpointedTransformer(nn.Module):
def __init__(self, n_layers, d_model, n_heads):
super().__init__()
self.blocks = nn.ModuleList([
TransformerBlock(d_model, n_heads)
for _ in range(n_layers)
])
def forward(self, x):
for i, block in enumerate(self.blocks):
# 每层都使用 checkpoint
x = checkpoint(
block,
x,
use_reentrant=False, # 推荐设为 False,避免一些限制
preserve_rng_state=False # 如果不需要 dropout 随机性
)
return xexamples/transformer_checkpoint.py
使用注意事项
| 情况 | 建议 |
|---|---|
use_reentrant=False | 推荐。支持梯度缩放、更少的限制 |
| 随机性(dropout) | 需要确定性时设置 preserve_rng_state=True |
| 检查点粒度 | 太细(每层都 checkpoint)会增加重算开销;太粗(整个模型)节省显存少 |
| 哪些层适合 checkpoint | 参数量小、计算快的层(如 Attention、FFN) |
| 哪些层不适合 | 参数量大的层(如嵌入层)、输入输出层 |
主流框架实现深度对比
不同深度学习框架对 Activation Checkpointing 的实现各有特色。本节深入剖析 DeepSpeed、Megatron-LM 和 Unsloth 三种主流框架的实现机制,帮助你根据实际需求做出选择。
5.1 实现架构概览
三种框架在 PyTorch 原生 checkpoint 基础上,针对大规模训练场景进行了不同程度的扩展:
┌─────────────────────────────────────────────────────────────────────────┐
│ 框架架构对比图 │
├─────────────────┬─────────────────┬─────────────────┬─────────────────────┤
│ PyTorch 原生 │ DeepSpeed │ Megatron-LM │ Unsloth │
├─────────────────┼─────────────────┼─────────────────┼─────────────────────┤
│ │ ┌───────────┐ │ ┌───────────┐ │ ┌───────────┐ │
│ checkpoint() │ │ 与 ZeRO │ │ │ Tensor/ │ │ │ 内存池 │ │
│ ↓ │ │ 深度集成 │ │ │ Pipeline │ │ │ 异步卸载 │ │
│ Checkpoint │ │ 协同优化 │ │ │ Parallel │ │ │ 融合算子 │ │
│ Function │ └───────────┘ │ │ 协同 │ │ └───────────┘ │
│ │ │ └───────────┘ │ │
│ • 基础功能 │ • 分区激活值 │ • 三种粒度 │ • 30% 额外节省 │
│ • 灵活可控 │ • CPU 卸载 │ • 层间并行 │ • 超长上下文 │
│ • 简单直接 │ • 连续内存 │ • 超大规模 │ • 易用性强 │
└─────────────────┴─────────────────┴─────────────────┴─────────────────────┘
设计哲学差异:
- PyTorch 原生:提供最基础的构建块,灵活性最高
- DeepSpeed:与 ZeRO 优化器深度协同,追求极致显存效率
- Megatron-LM:为千亿参数级 Transformer 设计,重视并行效率
- Unsloth:追求开箱即用的极致优化,降低使用门槛
5.2 DeepSpeed 详解
DeepSpeed 的 Activation Checkpointing 位于 deepspeed.runtime.activation_checkpointing.checkpointing 模块,核心类包括 CheckpointFunction 和 CudaRNGStatesTracker。
5.2.1 与 ZeRO 的协同机制
DeepSpeed 的独特优势在于 checkpoint 与 ZeRO(Zero Redundancy Optimizer)的深度集成:
┌─────────────────────────────────────────────────────────────────────────┐
│ DeepSpeed ZeRO + Checkpoint 协同架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ GPU 0 │ │ GPU 1 │ │ GPU 2 │ │
│ │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌──────────┐ │ │
│ │ │ ZeRO-2/3 │ │ │ │ ZeRO-2/3 │ │ │ │ ZeRO-2/3 │ │ 优化器 │
│ │ │ 参数分区 │ │ │ │ 参数分区 │ │ │ │ 参数分区 │ │ 状态分区 │
│ │ └──────────┘ │ │ └──────────┘ │ │ └──────────┘ │ │
│ │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌──────────┐ │ │
│ │ │ 激活分区 │ │◄────┼►│ 激活分区 │ │◄────┼►│ 激活分区 │ │ Checkpoint │
│ │ │ (1/N) │ │ │ │ (1/N) │ │ │ │ (1/N) │ │ 分区激活 │
│ │ └──────────┘ │ │ └──────────┘ │ │ └──────────┘ │ │
│ │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌──────────┐ │ │
│ │ │ CPU 缓冲区│ │ │ │ CPU 缓冲区│ │ │ │ CPU 缓冲区│ │ CPU │
│ │ │(Pinned) │ │ │ │(Pinned) │ │ │ │(Pinned) │ │ Offload │
│ │ └──────────┘ │ │ └──────────┘ │ │ └──────────┘ │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ │
│ 总显存占用 ≈ 模型参数/N + 优化器状态/N + 激活值/N + 梯度/N │
│ │
└─────────────────────────────────────────────────────────────────────────┘
关键特性详解:
-
partition_activations:跨 GPU 分区激活值
- 使用模型并行通信原语(allgather, reduce_scatter)
- 每个 GPU 只存储 1/N 的激活值(N 为 GPU 数量)
- 反向传播时通过 allgather 恢复完整激活值
-
cpu_checkpointing:将激活值卸载到 CPU
- 使用 pinned memory(锁页内存)加速 CPU↔GPU 传输
- 异步传输 (
non_blocking=True) 隐藏延迟 - 与
partition_activations配合使用,先分区再卸载
-
contiguous_memory_optimization:连续内存优化
- 预分配大的连续缓冲区,避免内存碎片
- 通过
number_checkpoints参数控制缓冲区数量 - 减少内存分配/释放的开销
5.2.2 核心代码解析
# DeepSpeed Activation Checkpointing 核心流程
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
# DeepSpeed 配置
deepspeed_config = {
"train_batch_size": 32,
"gradient_accumulation_steps": 4,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-5
}
},
"zero_optimization": {
"stage": 2, # ZeRO-2 优化器状态分区
"offload_optimizer": {
"device": "cpu",
"pin_memory": True
}
},
# Activation Checkpointing 配置
"activation_checkpointing": {
"partition_activations": True, # 跨 GPU 分区激活值
"cpu_checkpointing": True, # 卸载到 CPU
"contiguous_memory_optimization": True,
"number_checkpoints": None, # 自动决定
"synchronize_checkpoint_boundary": False, # 同步边界
"profile": False
}
}
# 初始化模型
model = MyLargeTransformerModel()
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config=deepspeed_config
)
# 在模型定义中使用 checkpoint
class DeepSpeedTransformerLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = SelfAttention(config)
self.mlp = MLP(config)
def forward(self, hidden_states, attention_mask):
# DeepSpeed checkpoint 自动处理分区、卸载等逻辑
return deepspeed.checkpointing.checkpoint(
self._forward_impl,
hidden_states,
attention_mask
)
def _forward_impl(self, hidden_states, attention_mask):
attn_output = self.attention(hidden_states, attention_mask)
mlp_output = self.mlp(attn_output)
return mlp_outputexamples/deepspeed_checkpoint.py
性能数据:
| 配置 | 显存节省 | 训练速度 | 适用场景 |
|---|---|---|---|
| 标准 checkpoint | ~50% | 基准 | 单卡训练 |
| + partition_activations | ~50-70% | -5% | 多卡数据并行 |
| + cpu_checkpointing | ~70-85% | -15% | 显存极度受限 |
| 全部启用 | ~85-90% | -20% | 超大规模模型 |
5.3 Megatron-LM 详解
Megatron-LM 是 NVIDIA 开源的用于训练千亿参数级 Transformer 的框架,其 checkpoint 实现与 Tensor Parallelism (TP) 和 Pipeline Parallelism (PP) 深度集成。
5.3.1 三种粒度策略
Megatron 提供三种不同的 checkpoint 粒度,适应不同场景:
┌─────────────────────────────────────────────────────────────────────────┐
│ Megatron Checkpoint 粒度对比 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. FULL(全量)模式 │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Input ──► [Transformer Block] ──► Output │ │
│ │ / | \ │ │
│ │ 重算 重算 重算 ◄── 全部重计算 │ │
│ │ 只保存 Input │ │
│ │ 显存节省: ★★★★★ 速度损失: ★★★★☆ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ 2. SELECTIVE(选择性)模式 │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Input ──► [Attention] ──► [FFN] ──► Output │ │
│ │ ↓保存 重算 │ │
│ │ 保留高效激活 重算内存密集部分 │ │
│ │ 显存节省: ★★★★☆ 速度损失: ★★★☆☆ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ 3. SELECTIVE_OUTPUT 模式 │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Input ──► [SubModule] ──► Output │ │
│ │ ↓保存 │ │
│ │ 丢弃特定输出 │ │
│ │ 显存节省: ★★★☆☆ 速度损失: ★★☆☆☆ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
5.3.2 Tensor Parallelism 下的 Checkpoint
在 TP 场景下,Megatron 的 checkpoint 有特殊优化:
# Megatron-LM Activation Checkpointing 配置
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core import parallel_state
# 命令行参数配置
"""
--recompute-granularity full # 全量重计算
--recompute-activations # 启用选择性重计算
--recompute-granularity selective # 选择性模式
--recompute-modules attention ffn # 指定重计算模块
--checkpoint-num-layers 1 # 每层 checkpoint
--recompute-method uniform # 均匀分布检查点
--recompute-method block # 块级检查点
"""
# 在模型定义中
class MegatronTransformerLayer(TransformerLayer):
def __init__(self, config):
super().__init__(config)
self.recompute_granularity = config.recompute_granularity
self.checkpoint_core_attention = config.recompute_core_attention
def forward(self, hidden_states, attention_mask):
# Megatron 自动处理 TP 下的 checkpoint
if self.checkpoint_activations and self.training:
# use_reentrant=False 避免一些限制
output = torch.utils.checkpoint.checkpoint(
self._forward_impl,
hidden_states,
attention_mask,
use_reentrant=False
)
else:
output = self._forward_impl(hidden_states, attention_mask)
return output
def _forward_impl(self, hidden_states, attention_mask):
# 列并行 Linear + 行并行 Linear 组合
# 只需要 all-reduce 通信
attention_output = self.self_attention(
hidden_states,
attention_mask,
recompute_core_attention=self.checkpoint_core_attention
)
mlp_output = self.mlp(attention_output)
return mlp_outputexamples/megatron_checkpoint.py
5.3.3 Pipeline Parallelism 下的 Checkpoint
在 PP 场景下,Megatron 需要协调多个 stage 的 checkpoint:
┌─────────────────────────────────────────────────────────────────────────┐
│ Pipeline Parallelism 下的 Checkpoint 策略 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 3 Pipeline Stages, 每层一个 Transformer Block: │
│ │
│ Stage 0 (GPU 0) Stage 1 (GPU 1) Stage 2 (GPU 2) │
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ │
│ │ [Block 0] │──────►│ [Block 1] │──────►│ [Block 2] │──► Loss │
│ │ │ │ │ │ │ │
│ │ checkpoint│ │ checkpoint│ │ checkpoint│ │
│ │ num_layers│ │ num_layers│ │ num_layers│ │
│ └───────────┘ └───────────┘ └───────────┘ │
│ │
│ 配置策略: │
│ • F-then-B schedule: 每个 micro-batch 单独 forward/backward │
│ • 1F1B schedule: 交错 forward/backward, checkpoint 需要与 stage 对齐 │
│ • Interleaved 1F1B: 虚拟 stage, 更复杂的 checkpoint 管理 │
│ │
│ checkpoint_num_layers 设计原理: │
│ • = 1: 每层都 checkpoint,显存最优,速度最慢 │
│ • = stage 层数: 每个 stage 只保存输入,类似 full 模式 │
│ • 根据显存预算动态调整 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
与 DeepSpeed 的关键差异:
| 特性 | DeepSpeed | Megatron-LM |
|---|---|---|
| 主要优化目标 | 显存效率最大化 | 并行效率最大化 |
| 并行支持 | 数据并行 (ZeRO) | TP + PP + DP 混合并行 |
| Checkpoint 粒度 | 层级别 | 模块级别(更细粒度控制) |
| 适用模型规模 | 中大规模 (1B-100B) | 超大规模 (100B+) |
| 使用复杂度 | 中等 | 较高 |
5.4 Unsloth 详解
Unsloth 是专注于 LLM 微调的框架,其 checkpoint 实现在 PyTorch 基础上增加了多项优化,主打易用性和极致显存效率。
5.4.1 “unsloth” 模式
Unsloth 提供了两种 checkpoint 模式:
True:使用 PyTorch 原生 checkpoint"unsloth":使用 Unsloth 优化版本,额外节省 30% VRAM
# Unsloth Gradient Checkpointing 使用示例
from unsloth import FastLanguageModel
import torch
# 加载模型
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/llama-3-8b-bnb-4bit",
max_seq_length=8192,
dtype=None, # 自动选择
load_in_4bit=True, # 4-bit 量化
)
# 配置 LoRA(可选,但推荐)
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth", # 关键参数!
random_state=3407,
)
# "unsloth" 模式的优化包括:
# 1. 异步 CPU 卸载 - 将不急需的激活值异步转移到 CPU
# 2. 内存池复用 - 预分配内存池,避免频繁 malloc/free
# 3. 融合算子 - 减少中间激活值的产生
# 4. 超长上下文优化 - 支持 128K+ 上下文训练
# 训练配置
training_args = TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
num_train_epochs=3,
learning_rate=2e-4,
fp16=not torch.cuda.is_bf16_supported(),
bf16=torch.cuda.is_bf16_supported(),
# Unsloth 自动处理 checkpoint 细节
)
# 对比:PyTorch 原生 vs Unsloth 优化
# ┌──────────────────────┬─────────────┬──────────────┐
# │ 配置 │ 显存占用 │ 最大序列长度 │
# ├──────────────────────┼─────────────┼──────────────┤
# │ 无 checkpoint │ 100% │ 2K │
# │ checkpoint=True │ ~50% │ 4K │
# │ checkpoint="unsloth" │ ~35% │ 8K+ │
# └──────────────────────┴─────────────┴──────────────┘examples/unsloth_checkpoint.py
Unsloth 优化的核心机制:
- 智能内存池:预分配固定大小的内存缓冲区,激活值在池内循环复用
- 异步卸载流水线:利用 GPU 计算和 PCIe 传输的并行性
- 量化感知 checkpoint:对保存的激活值进行低精度量化
- Flash Attention 集成:与 Flash Attention 2/3 配合,进一步优化显存
5.5 框架选择指南
5.5.1 多维度对比
| 维度 | PyTorch 原生 | DeepSpeed | Megatron-LM | Unsloth |
|---|---|---|---|---|
| 上手难度 | ⭐ 简单 | ⭐⭐ 中等 | ⭐⭐⭐ 复杂 | ⭐ 简单 |
| 显存节省 | ~50% | ~70-85% | ~60-70% | ~65-80% |
| 速度损失 | ~20% | ~15-20% | ~10-25% | ~10% |
| 最大模型 | 10B | 100B+ | 1T+ | 70B |
| 并行支持 | DDP | ZeRO-DP | TP+PP+DP | DDP |
| 特色功能 | 灵活可控 | CPU 卸载 | 层间并行 | 超长上下文 |
| 推荐场景 | 入门学习 | 通用训练 | 超大规模 | 快速微调 |
5.5.2 决策流程图
┌─────────────────┐
│ 开始选择框架 │
└────────┬────────┘
│
┌──────────────┴──────────────┐
│ │
模型 > 100B? 只需要微调?
│ │
是 ───┴─── 否 是 ───┴─── 否
│ │
▼ ▼
┌───────────────┐ ┌───────────────┐
│ Megatron-LM │ │ Unsloth │
│ (TP+PP 必需) │ │ (最简单高效) │
└───────────────┘ └───────┬───────┘
│
显存极度受限且有多卡?
│
是 ───┴─── 否
│
▼
┌───────────────┐
│ DeepSpeed │
│ (ZeRO + CPU │
│ offload) │
└───────────────┘
│
▼
┌───────────────┐
│ PyTorch 原生 │
│ (简单直接) │
└───────────────┘
5.5.3 场景化推荐
| 场景 | 推荐框架 | 配置建议 |
|---|---|---|
| 学习/原型验证 | PyTorch 原生 | checkpoint(use_reentrant=False) |
| 单卡 7B/13B 微调 | Unsloth | use_gradient_checkpointing="unsloth" |
| 多卡 30B-70B 训练 | DeepSpeed | ZeRO-3 + partition_activations |
| 显存 < 24GB 训练 | DeepSpeed | cpu_checkpointing=True |
| 100B+ 预训练 | Megatron-LM | TP + PP + selective checkpoint |
| 超长上下文 (>32K) | Unsloth | 自动优化 |
为什么能省显存?原理简述
反向传播需要用到前向计算的激活值来计算梯度。对于一层 ,计算 需要知道该层的输入激活 。
正常训练:保存所有层的激活值 ,显存占用 。
Checkpointing:只保存检查点层的激活值。反向到某层时,从最近的检查点重新前向计算得到该层激活。显存占用 (常数,取决于检查点间隔)。
代价是前向计算执行了两次(一次正常前向,一次反向时的重算),所以训练时间增加约 20%~30%。
什么时候用?什么时候不用?
推荐使用:
- 显存不够,只能把 batch_size 设得很小(如 1)
- 模型参数量 > 1B,激活值是显存瓶颈
- 能接受训练时间增加 20%~30%
不推荐使用:
- 显存充足(GPU 利用率低)
- 推理阶段(不需要反向传播,checkpoint 无意义)
- 对延迟极其敏感的场景(如在线服务训练)
小结
Activation Checkpointing 是最简单有效的显存优化手段之一。核心思想是时间换空间:用 20% 的训练时间增加,换取 50%~70% 的显存节省。
对于初学者,建议从 PyTorch 原生的 torch.utils.checkpoint 开始。等熟悉后,可以尝试 DeepSpeed 或 Unsloth 来获得更好的优化效果。下一步可以学习 ZeRO 优化器状态分区、混合精度训练等技术,它们与 checkpoint 是正交的,可以叠加使用。