Skip to content
传衡博客
返回

Activation Checkpointing:用时间换显存的艺术

训练大模型时最崩溃的时刻,莫过于看到 RuntimeError: CUDA out of memoryActivation Checkpointing(激活检查点,也叫梯度检查点)能把显存占用降低 50%~70%,代价是训练时间增加约 20%。这篇文章会用一个简单的类比帮你建立直觉,然后手把手教你用起来。

参考资料
  1. Training Deep Nets with Sublinear Memory Cost - Chen et al., 2016
  2. PyTorch Checkpoint Documentation
  3. DeepSpeed Activation Checkpointing
  4. Megatron-LM Technical Report
  5. Unsloth Memory Optimization

显存都去哪了?

训练神经网络时,GPU 显存主要被四部分占用:

组成部分1B 模型示例说明
模型参数4 GB存储权重
梯度4 GB反向传播计算得到
优化器状态8 GBAdam 需要维护动量等状态
激活值32 GB前向计算的中间结果
合计48 GBbatch_size=4, seq_len=2048

激活值是最大的显存消耗者。对于 Transformer 模型,激活值显存与层数、序列长度、batch size 成正比。当模型变大或序列变长时,激活值会迅速吃掉所有显存。

草稿纸类比

想象你在做一道很长的数学题,需要多步计算。

正常训练的做法:每一步计算都把中间结果写在草稿纸上。这样回头检查时,所有中间结果都随手可得。问题是,草稿纸(显存)很快就用完了。

Activation Checkpointing 的做法:只在关键步骤(检查点)保存中间结果,其他步骤需要时重新计算。代价是多花点时间重算,但草稿纸用量大幅减少。

具体到神经网络:

流程对比:ASCII 图示

假设有一个 4 层的网络(Layer 1 → Layer 2 → Layer 3 → Layer 4):

正常训练

前向传播:
Input → [L1] → [L2] → [L3] → [L4] → Loss
        ↓      ↓      ↓      ↓
       保存   保存   保存   保存  ← 所有激活值都存

反向传播:
Loss ← [L4] ← [L3] ← [L2] ← [L1] ← Grad
        ↑      ↑      ↑      ↑
       读取   读取   读取   读取  ← 直接使用保存的激活值

显存占用:O(L) 随层数线性增长

Activation Checkpointing(以 Layer 2 为检查点)

前向传播:
Input → [L1] → [L2] → [L3] → [L4] → Loss
        ↓      ✓      ↓      ↓
       丢弃   保存   丢弃   丢弃  ← 只保存检查点

反向传播:
Loss ← [L4] ← [L3] ← [L2] ← [L1] ← Grad
               ↑      ↑
              重算 ← 从检查点重新前向到 L3
        ↑      ↑
       重算 ← 从 Input 重新前向到 L1

显存占用:O(1) 常数级别(只存检查点)

PyTorch 实战

PyTorch 提供了 torch.utils.checkpoint 来实现这一功能。

基础用法:一个简单的 MLP

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1000, 1000)
        self.layer2 = nn.Linear(1000, 1000)
        self.layer3 = nn.Linear(1000, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        # 普通前向:直接调用
        # x = self.relu(self.layer1(x))
        # x = self.relu(self.layer2(x))

        # Checkpoint 版本:用 checkpoint 包装
        x = checkpoint(self._forward_block, x, use_reentrant=False)
        return x

    def _forward_block(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.layer3(x)
        return x

# 显存对比测试
def test_memory():
    model = SimpleMLP().cuda()
    x = torch.randn(32, 1000, requires_grad=True).cuda()

    # 普通训练
    torch.cuda.reset_peak_memory_stats()
    out = model._forward_block(x)
    out.sum().backward()
    mem_normal = torch.cuda.max_memory_allocated() / 1024**2

    # Checkpoint 训练
    model.zero_grad(set_to_none=True)
    torch.cuda.reset_peak_memory_stats()
    out = model(x)
    out.sum().backward()
    mem_checkpoint = torch.cuda.max_memory_allocated() / 1024**2

    print(f"普通模式显存: {mem_normal:.1f} MB")
    print(f"Checkpoint 模式显存: {mem_checkpoint:.1f} MB")
    print(f"节省: {(1 - mem_checkpoint/mem_normal)*100:.1f}%")

test_memory()examples/basic_checkpoint.py

运行结果示例:

普通模式显存: 256.0 MB
Checkpoint 模式显存: 128.0 MB
节省: 50.0%

在 Transformer 中的应用

实际使用中,通常对整个 Transformer Block 应用 checkpoint:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # 自注意力
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)

        # 前馈网络
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)

        return x

class CheckpointedTransformer(nn.Module):
    def __init__(self, n_layers, d_model, n_heads):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads)
            for _ in range(n_layers)
        ])

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            # 每层都使用 checkpoint
            x = checkpoint(
                block,
                x,
                use_reentrant=False,  # 推荐设为 False,避免一些限制
                preserve_rng_state=False  # 如果不需要 dropout 随机性
            )
        return xexamples/transformer_checkpoint.py

使用注意事项

情况建议
use_reentrant=False推荐。支持梯度缩放、更少的限制
随机性(dropout)需要确定性时设置 preserve_rng_state=True
检查点粒度太细(每层都 checkpoint)会增加重算开销;太粗(整个模型)节省显存少
哪些层适合 checkpoint参数量小、计算快的层(如 Attention、FFN)
哪些层不适合参数量大的层(如嵌入层)、输入输出层

主流框架实现深度对比

不同深度学习框架对 Activation Checkpointing 的实现各有特色。本节深入剖析 DeepSpeed、Megatron-LM 和 Unsloth 三种主流框架的实现机制,帮助你根据实际需求做出选择。

5.1 实现架构概览

三种框架在 PyTorch 原生 checkpoint 基础上,针对大规模训练场景进行了不同程度的扩展:

┌─────────────────────────────────────────────────────────────────────────┐
│                         框架架构对比图                                    │
├─────────────────┬─────────────────┬─────────────────┬─────────────────────┤
│   PyTorch 原生   │    DeepSpeed    │   Megatron-LM   │      Unsloth        │
├─────────────────┼─────────────────┼─────────────────┼─────────────────────┤
│                 │  ┌───────────┐  │  ┌───────────┐  │   ┌───────────┐     │
│  checkpoint()   │  │ 与 ZeRO   │  │  │ Tensor/   │  │   │ 内存池    │     │
│       ↓         │  │ 深度集成   │  │  │ Pipeline  │  │   │ 异步卸载  │     │
│  Checkpoint     │  │ 协同优化   │  │  │ Parallel  │  │   │ 融合算子  │     │
│  Function       │  └───────────┘  │  │ 协同       │  │   └───────────┘     │
│                 │                 │  └───────────┘  │                     │
│  • 基础功能      │  • 分区激活值    │  • 三种粒度    │   • 30% 额外节省     │
│  • 灵活可控      │  • CPU 卸载     │  • 层间并行    │   • 超长上下文       │
│  • 简单直接      │  • 连续内存     │  • 超大规模    │   • 易用性强         │
└─────────────────┴─────────────────┴─────────────────┴─────────────────────┘

设计哲学差异

5.2 DeepSpeed 详解

DeepSpeed 的 Activation Checkpointing 位于 deepspeed.runtime.activation_checkpointing.checkpointing 模块,核心类包括 CheckpointFunctionCudaRNGStatesTracker

5.2.1 与 ZeRO 的协同机制

DeepSpeed 的独特优势在于 checkpoint 与 ZeRO(Zero Redundancy Optimizer)的深度集成:

┌─────────────────────────────────────────────────────────────────────────┐
│                    DeepSpeed ZeRO + Checkpoint 协同架构                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   ┌──────────────┐     ┌──────────────┐     ┌──────────────┐           │
│   │   GPU 0      │     │   GPU 1      │     │   GPU 2      │           │
│   │ ┌──────────┐ │     │ ┌──────────┐ │     │ ┌──────────┐ │           │
│   │ │ ZeRO-2/3 │ │     │ │ ZeRO-2/3 │ │     │ │ ZeRO-2/3 │ │  优化器    │
│   │ │ 参数分区  │ │     │ │ 参数分区  │ │     │ │ 参数分区  │ │  状态分区  │
│   │ └──────────┘ │     │ └──────────┘ │     │ └──────────┘ │           │
│   │ ┌──────────┐ │     │ ┌──────────┐ │     │ ┌──────────┐ │           │
│   │ │ 激活分区  │ │◄────┼►│ 激活分区  │ │◄────┼►│ 激活分区  │ │  Checkpoint │
│   │ │ (1/N)    │ │     │ │ (1/N)    │ │     │ │ (1/N)    │ │  分区激活   │
│   │ └──────────┘ │     │ └──────────┘ │     │ └──────────┘ │           │
│   │ ┌──────────┐ │     │ ┌──────────┐ │     │ ┌──────────┐ │           │
│   │ │ CPU 缓冲区│ │     │ │ CPU 缓冲区│ │     │ │ CPU 缓冲区│ │  CPU      │
│   │ │(Pinned)  │ │     │ │(Pinned)  │ │     │ │(Pinned)  │ │  Offload  │
│   │ └──────────┘ │     │ └──────────┘ │     │ └──────────┘ │           │
│   └──────────────┘     └──────────────┘     └──────────────┘           │
│                                                                         │
│   总显存占用 ≈ 模型参数/N + 优化器状态/N + 激活值/N + 梯度/N            │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

关键特性详解

  1. partition_activations:跨 GPU 分区激活值

    • 使用模型并行通信原语(allgather, reduce_scatter)
    • 每个 GPU 只存储 1/N 的激活值(N 为 GPU 数量)
    • 反向传播时通过 allgather 恢复完整激活值
  2. cpu_checkpointing:将激活值卸载到 CPU

    • 使用 pinned memory(锁页内存)加速 CPU↔GPU 传输
    • 异步传输 (non_blocking=True) 隐藏延迟
    • partition_activations 配合使用,先分区再卸载
  3. contiguous_memory_optimization:连续内存优化

    • 预分配大的连续缓冲区,避免内存碎片
    • 通过 number_checkpoints 参数控制缓冲区数量
    • 减少内存分配/释放的开销

5.2.2 核心代码解析

# DeepSpeed Activation Checkpointing 核心流程

import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint

# DeepSpeed 配置
deepspeed_config = {
    "train_batch_size": 32,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 1e-5
        }
    },
    "zero_optimization": {
        "stage": 2,  # ZeRO-2 优化器状态分区
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        }
    },
    # Activation Checkpointing 配置
    "activation_checkpointing": {
        "partition_activations": True,      # 跨 GPU 分区激活值
        "cpu_checkpointing": True,          # 卸载到 CPU
        "contiguous_memory_optimization": True,
        "number_checkpoints": None,          # 自动决定
        "synchronize_checkpoint_boundary": False,  # 同步边界
        "profile": False
    }
}

# 初始化模型
model = MyLargeTransformerModel()
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=deepspeed_config
)

# 在模型定义中使用 checkpoint
class DeepSpeedTransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = SelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, hidden_states, attention_mask):
        # DeepSpeed checkpoint 自动处理分区、卸载等逻辑
        return deepspeed.checkpointing.checkpoint(
            self._forward_impl,
            hidden_states,
            attention_mask
        )

    def _forward_impl(self, hidden_states, attention_mask):
        attn_output = self.attention(hidden_states, attention_mask)
        mlp_output = self.mlp(attn_output)
        return mlp_outputexamples/deepspeed_checkpoint.py

性能数据

配置显存节省训练速度适用场景
标准 checkpoint~50%基准单卡训练
+ partition_activations~50-70%-5%多卡数据并行
+ cpu_checkpointing~70-85%-15%显存极度受限
全部启用~85-90%-20%超大规模模型

5.3 Megatron-LM 详解

Megatron-LM 是 NVIDIA 开源的用于训练千亿参数级 Transformer 的框架,其 checkpoint 实现与 Tensor Parallelism (TP) 和 Pipeline Parallelism (PP) 深度集成。

5.3.1 三种粒度策略

Megatron 提供三种不同的 checkpoint 粒度,适应不同场景:

┌─────────────────────────────────────────────────────────────────────────┐
│                   Megatron Checkpoint 粒度对比                           │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  1. FULL(全量)模式                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │  Input ──► [Transformer Block] ──► Output                       │   │
│  │            /       |        \                                   │   │
│  │          重算    重算      重算  ◄── 全部重计算                  │   │
│  │          只保存 Input                                         │   │
│  │  显存节省: ★★★★★  速度损失: ★★★★☆                               │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│  2. SELECTIVE(选择性)模式                                              │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │  Input ──► [Attention] ──► [FFN] ──► Output                     │   │
│  │            ↓保存          重算                                  │   │
│  │          保留高效激活    重算内存密集部分                         │   │
│  │  显存节省: ★★★★☆  速度损失: ★★★☆☆                               │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│  3. SELECTIVE_OUTPUT 模式                                               │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │  Input ──► [SubModule] ──► Output                               │   │
│  │                         ↓保存                                   │   │
│  │                      丢弃特定输出                               │   │
│  │  显存节省: ★★★☆☆  速度损失: ★★☆☆☆                               │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

5.3.2 Tensor Parallelism 下的 Checkpoint

在 TP 场景下,Megatron 的 checkpoint 有特殊优化:

# Megatron-LM Activation Checkpointing 配置

from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core import parallel_state

# 命令行参数配置
"""
--recompute-granularity full          # 全量重计算
--recompute-activations               # 启用选择性重计算
--recompute-granularity selective     # 选择性模式
--recompute-modules attention ffn     # 指定重计算模块
--checkpoint-num-layers 1             # 每层 checkpoint
--recompute-method uniform            # 均匀分布检查点
--recompute-method block              # 块级检查点
"""

# 在模型定义中
class MegatronTransformerLayer(TransformerLayer):
    def __init__(self, config):
        super().__init__(config)
        self.recompute_granularity = config.recompute_granularity
        self.checkpoint_core_attention = config.recompute_core_attention

    def forward(self, hidden_states, attention_mask):
        # Megatron 自动处理 TP 下的 checkpoint
        if self.checkpoint_activations and self.training:
            # use_reentrant=False 避免一些限制
            output = torch.utils.checkpoint.checkpoint(
                self._forward_impl,
                hidden_states,
                attention_mask,
                use_reentrant=False
            )
        else:
            output = self._forward_impl(hidden_states, attention_mask)
        return output

    def _forward_impl(self, hidden_states, attention_mask):
        # 列并行 Linear + 行并行 Linear 组合
        # 只需要 all-reduce 通信
        attention_output = self.self_attention(
            hidden_states,
            attention_mask,
            recompute_core_attention=self.checkpoint_core_attention
        )
        mlp_output = self.mlp(attention_output)
        return mlp_outputexamples/megatron_checkpoint.py

5.3.3 Pipeline Parallelism 下的 Checkpoint

在 PP 场景下,Megatron 需要协调多个 stage 的 checkpoint:

┌─────────────────────────────────────────────────────────────────────────┐
│              Pipeline Parallelism 下的 Checkpoint 策略                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  3 Pipeline Stages, 每层一个 Transformer Block:                          │
│                                                                         │
│  Stage 0 (GPU 0)      Stage 1 (GPU 1)      Stage 2 (GPU 2)             │
│  ┌───────────┐       ┌───────────┐       ┌───────────┐                 │
│  │ [Block 0] │──────►│ [Block 1] │──────►│ [Block 2] │──► Loss         │
│  │           │       │           │       │           │                 │
│  │ checkpoint│       │ checkpoint│       │ checkpoint│                 │
│  │ num_layers│       │ num_layers│       │ num_layers│                 │
│  └───────────┘       └───────────┘       └───────────┘                 │
│                                                                         │
│  配置策略:                                                               │
│  • F-then-B schedule: 每个 micro-batch 单独 forward/backward            │
│  • 1F1B schedule: 交错 forward/backward, checkpoint 需要与 stage 对齐   │
│  • Interleaved 1F1B: 虚拟 stage, 更复杂的 checkpoint 管理                │
│                                                                         │
│  checkpoint_num_layers 设计原理:                                        │
│  • = 1: 每层都 checkpoint,显存最优,速度最慢                            │
│  • = stage 层数: 每个 stage 只保存输入,类似 full 模式                   │
│  • 根据显存预算动态调整                                                  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

与 DeepSpeed 的关键差异

特性DeepSpeedMegatron-LM
主要优化目标显存效率最大化并行效率最大化
并行支持数据并行 (ZeRO)TP + PP + DP 混合并行
Checkpoint 粒度层级别模块级别(更细粒度控制)
适用模型规模中大规模 (1B-100B)超大规模 (100B+)
使用复杂度中等较高

5.4 Unsloth 详解

Unsloth 是专注于 LLM 微调的框架,其 checkpoint 实现在 PyTorch 基础上增加了多项优化,主打易用性和极致显存效率

5.4.1 “unsloth” 模式

Unsloth 提供了两种 checkpoint 模式:

# Unsloth Gradient Checkpointing 使用示例

from unsloth import FastLanguageModel
import torch

# 加载模型
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",
    max_seq_length=8192,
    dtype=None,  # 自动选择
    load_in_4bit=True,  # 4-bit 量化
)

# 配置 LoRA(可选,但推荐)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # 关键参数!
    random_state=3407,
)

# "unsloth" 模式的优化包括:
# 1. 异步 CPU 卸载 - 将不急需的激活值异步转移到 CPU
# 2. 内存池复用 - 预分配内存池,避免频繁 malloc/free
# 3. 融合算子 - 减少中间激活值的产生
# 4. 超长上下文优化 - 支持 128K+ 上下文训练

# 训练配置
training_args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    # Unsloth 自动处理 checkpoint 细节
)

# 对比:PyTorch 原生 vs Unsloth 优化
# ┌──────────────────────┬─────────────┬──────────────┐
# │ 配置                  │ 显存占用    │ 最大序列长度  │
# ├──────────────────────┼─────────────┼──────────────┤
# │ 无 checkpoint        │ 100%        │ 2K           │
# │ checkpoint=True      │ ~50%        │ 4K           │
# │ checkpoint="unsloth" │ ~35%        │ 8K+          │
# └──────────────────────┴─────────────┴──────────────┘examples/unsloth_checkpoint.py

Unsloth 优化的核心机制

  1. 智能内存池:预分配固定大小的内存缓冲区,激活值在池内循环复用
  2. 异步卸载流水线:利用 GPU 计算和 PCIe 传输的并行性
  3. 量化感知 checkpoint:对保存的激活值进行低精度量化
  4. Flash Attention 集成:与 Flash Attention 2/3 配合,进一步优化显存

5.5 框架选择指南

5.5.1 多维度对比

维度PyTorch 原生DeepSpeedMegatron-LMUnsloth
上手难度⭐ 简单⭐⭐ 中等⭐⭐⭐ 复杂⭐ 简单
显存节省~50%~70-85%~60-70%~65-80%
速度损失~20%~15-20%~10-25%~10%
最大模型10B100B+1T+70B
并行支持DDPZeRO-DPTP+PP+DPDDP
特色功能灵活可控CPU 卸载层间并行超长上下文
推荐场景入门学习通用训练超大规模快速微调

5.5.2 决策流程图

                          ┌─────────────────┐
                          │ 开始选择框架     │
                          └────────┬────────┘

                    ┌──────────────┴──────────────┐
                    │                             │
              模型 > 100B?                   只需要微调?
                    │                             │
              是 ───┴─── 否                是 ───┴─── 否
              │                             │
              ▼                             ▼
      ┌───────────────┐              ┌───────────────┐
      │ Megatron-LM   │              │   Unsloth     │
      │ (TP+PP 必需)  │              │ (最简单高效)   │
      └───────────────┘              └───────┬───────┘

                              显存极度受限且有多卡?

                                       是 ───┴─── 否


                              ┌───────────────┐
                              │   DeepSpeed   │
                              │ (ZeRO + CPU   │
                              │  offload)     │
                              └───────────────┘


                              ┌───────────────┐
                              │ PyTorch 原生  │
                              │ (简单直接)    │
                              └───────────────┘

5.5.3 场景化推荐

场景推荐框架配置建议
学习/原型验证PyTorch 原生checkpoint(use_reentrant=False)
单卡 7B/13B 微调Unslothuse_gradient_checkpointing="unsloth"
多卡 30B-70B 训练DeepSpeedZeRO-3 + partition_activations
显存 < 24GB 训练DeepSpeedcpu_checkpointing=True
100B+ 预训练Megatron-LMTP + PP + selective checkpoint
超长上下文 (>32K)Unsloth自动优化

为什么能省显存?原理简述

反向传播需要用到前向计算的激活值来计算梯度。对于一层 ll,计算 LWl\frac{\partial L}{\partial W_l} 需要知道该层的输入激活 ala_l

正常训练:保存所有层的激活值 a1,a2,...,aLa_1, a_2, ..., a_L,显存占用 O(L)O(L)

Checkpointing:只保存检查点层的激活值。反向到某层时,从最近的检查点重新前向计算得到该层激活。显存占用 O(1)O(1)(常数,取决于检查点间隔)。

代价是前向计算执行了两次(一次正常前向,一次反向时的重算),所以训练时间增加约 20%~30%。

什么时候用?什么时候不用?

推荐使用

不推荐使用

小结

Activation Checkpointing 是最简单有效的显存优化手段之一。核心思想是时间换空间:用 20% 的训练时间增加,换取 50%~70% 的显存节省。

对于初学者,建议从 PyTorch 原生的 torch.utils.checkpoint 开始。等熟悉后,可以尝试 DeepSpeed 或 Unsloth 来获得更好的优化效果。下一步可以学习 ZeRO 优化器状态分区、混合精度训练等技术,它们与 checkpoint 是正交的,可以叠加使用。



Previous Post
手撕大模型核心算子
Next Post
MHA vs MQA vs GQA vs MLA:四种 Attention 机制显存与性能全对比