Skip to content
传衡博客
返回

手撕大模型核心算子

大模型的核心能力来自这些基础算子的组合。面试中手撕 Attention、写 LayerNorm 是常规操作,但很多人只背了公式,没真正动手实现过。本文从激活函数到注意力机制,逐个拆解 PyTorch 版的精简实现,标注面试高频考点。

激活函数

Sigmoid

Sigmoid 将任意实数映射到 (0, 1),适合表示概率。现代 LLM 已很少直接使用,但在门控机制(如 LSTM、GRU)中仍有应用。

σ(x)=11+ex\sigma(x) = \frac{1}{1 + e^{-x}}

导数有一个优美的性质:σ(x)=σ(x)(1σ(x))\sigma'(x) = \sigma(x)(1 - \sigma(x)),这使得反向传播计算非常高效。

import torch
import torch.nn as nn

class Sigmoid(nn.Module):
    """Sigmoid 激活函数"""
    def forward(self, x):
        return 1 / (1 + torch.exp(-x))

# 验证
x = torch.randn(3, 4)
my_sigmoid = Sigmoid()
torch_sigmoid = nn.Sigmoid()

assert torch.allclose(my_sigmoid(x), torch_sigmoid(x), atol=1e-6)
print(f"最大误差: {(my_sigmoid(x) - torch_sigmoid(x)).abs().max():.2e}")

面试高频: Sigmoid 的梯度消失问题是如何产生的?

当输入绝对值较大时(x > 5 或 x < -5),Sigmoid 输出接近 0 或 1,导数接近 0。多层堆叠后,梯度连乘导致梯度消失。这也是为什么深度网络更偏好 ReLU 及其变体。

GELU

GELU(Gaussian Error Linear Unit)是 BERT、GPT 系列的标准激活。相比 ReLU 的硬截断,GELU 提供平滑的非线性变换。

精确计算依赖误差函数 erf:

GELU(x)=xΦ(x)=x12[1+erf(x2)]\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]

工程上更常用 tanh 近似(Hendrycks 2016):

GELU(x)0.5x(1+tanh[2π(x+0.044715x3)])\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)
class GELU(nn.Module):
    """GELU 激活函数(tanh 近似版)"""
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))

# 验证
x = torch.randn(3, 4)
my_gelu = GELU()
torch_gelu = nn.GELU(approximate='tanh')

assert torch.allclose(my_gelu(x), torch_gelu(x), atol=1e-5)
print(f"最大误差: {(my_gelu(x) - torch_gelu(x)).abs().max():.2e}")

面试高频: GELU 与 ReLU 的本质区别是什么?

ReLU 是确定性的硬阈值(x < 0 则输出 0),GELU 则是概率性的软过渡。从随机正则化角度理解:GELU 等价于对输入施加 Dropout 的期望输出。这种平滑性让优化更稳定,尤其在深层网络中。

SiLU / Swish

SiLU(Sigmoid Linear Unit),也叫 Swish,是 LLaMA、Mistral 等现代 LLM 的首选激活。Google Brain 2017 年提出,发现其性能在深层网络上优于 ReLU。

SiLU(x)=xσ(x)\text{SiLU}(x) = x \cdot \sigma(x)

结构上是输入与 sigmoid 门的逐元素乘积,实现了自门控(self-gated)机制。

class SiLU(nn.Module):
    """SiLU / Swish 激活函数"""
    def forward(self, x):
        return x * torch.sigmoid(x)

# 验证
x = torch.randn(3, 4)
my_silu = SiLU()
torch_silu = nn.SiLU()

assert torch.allclose(my_silu(x), torch_silu(x), atol=1e-6)
print(f"最大误差: {(my_silu(x) - torch_silu(x)).abs().max():.2e}")

面试高频: LLaMA 为什么选 SiLU 而不是 GELU?

两者性能差距不大,但 SiLU 计算量略小(没有立方运算)。PaLM、LLaMA-2 的实验表明,在足够大的模型和数据上,SiLU 的收敛速度稍快。这属于工程实践中的细微取舍。

归一化层

BatchNorm vs LayerNorm vs RMSNorm

三种归一化的核心差异在于沿着哪个维度计算统计量

归一化计算维度适用场景大模型使用
BatchNorm(N, H, W) 的 batch 维CNN几乎不用
LayerNorm(H) 的特征维Transformer标配
RMSNorm(H) 的特征维(仅 std)现代 LLMLLaMA/Qwen 首选

BatchNorm 依赖 batch 统计量,在序列长度变长的 Transformer 中不稳定。LayerNorm 和 RMSNorm 对每个样本独立计算,更适合变长序列。

LayerNorm

LayerNorm 对每个样本的所有特征做归一化,消除样本间的分布差异。

μ=1Hi=1Hxiσ2=1Hi=1H(xiμ)2y=γxμσ2+ϵ+β\begin{aligned} \mu &= \frac{1}{H}\sum_{i=1}^H x_i \\ \sigma^2 &= \frac{1}{H}\sum_{i=1}^H (x_i - \mu)^2 \\ y &= \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \end{aligned}

γ\gammaβ\beta 是可学习的缩放和平移参数。

数值示例:假设有一个 4 维输入向量,γ=1,β=0\gamma=1, \beta=0

步骤计算结果
原始输入xx[1.0, 2.0, 3.0, 4.0]
均值μ=1+2+3+44\mu = \frac{1+2+3+4}{4}2.5
方差σ2=(xiμ)24\sigma^2 = \frac{\sum(x_i-\mu)^2}{4}1.25
归一化xμσ2\frac{x-\mu}{\sqrt{\sigma^2}}[-1.34, -0.45, 0.45, 1.34]
输出γnorm+β\gamma \cdot \text{norm} + \beta[-1.34, -0.45, 0.45, 1.34]

归一化后均值为 0,方差为 1。实际训练中 γ\gammaβ\beta 会被学习成适合当前层的分布。

class LayerNorm(nn.Module):
    """Layer Normalization"""
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        # x: (..., normalized_shape)
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

# 验证数值示例
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
ln = LayerNorm(4)

with torch.no_grad():
    # 手动计算验证
    mean = x.mean()           # 2.5
    var = x.var(unbiased=False)  # 1.25
    expected = (x - mean) / torch.sqrt(var + 1e-5)
    print(f"均值: {mean:.2f}, 方差: {var:.2f}")
    print(f"预期归一化: {expected}")
    print(f"实际输出: {ln(x)}")

# 与官方实现对比
ln = LayerNorm(64)
torch_ln = nn.LayerNorm(64)
x = torch.randn(2, 10, 64)
with torch.no_grad():
    torch_ln.weight.copy_(ln.gamma)
    torch_ln.bias.copy_(ln.beta)
assert torch.allclose(ln(x), torch_ln(x), atol=1e-5)
print("LayerNorm 验证通过")

面试高频: Transformer 为什么用 LayerNorm 而不是 BatchNorm?

  1. 序列长度可变,BatchNorm 的 batch 统计量不稳定
  2. 推理时 BatchNorm 需要维护 running statistics,LayerNorm 训练和测试行为一致
  3. LayerNorm 对样本独立计算,更适合自回归生成场景

RMSNorm

RMSNorm(Root Mean Square LayerNorm)去掉了 LayerNorm 的 mean 计算,只保留 RMS 缩放。LLaMA-2 技术报告提到这能略微提升训练稳定性。

y=x1Hi=1Hxi2+ϵγy = \frac{x}{\sqrt{\frac{1}{H}\sum_{i=1}^H x_i^2 + \epsilon}} \cdot \gamma

数值对比:使用与 LayerNorm 示例相同的输入 [1.0, 2.0, 3.0, 4.0]:

归一化计算输出结果
LayerNormxμσ\frac{x-\mu}{\sigma}[-1.34, -0.45, 0.45, 1.34]
RMSNormxRMS\frac{x}{\text{RMS}}[0.37, 0.73, 1.10, 1.46]

注意 RMSNorm 不中心化为 0,而是保留原始数值的相对大小。对于 [1,2,3,4] 这种全正输入,RMSNorm 输出也都是正数。

class RMSNorm(nn.Module):
    """RMS Normalization"""
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        # 只计算 RMS,不做中心化
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms
        return self.weight * x_norm

# 验证数值示例
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
rms = RMSNorm(4)

with torch.no_grad():
    rms_val = torch.sqrt(torch.mean(x ** 2))  # sqrt(30/4) = 2.74
    expected = x / rms_val
    print(f"RMS: {rms_val:.2f}")
    print(f"预期输出: {expected}")
    print(f"实际输出: {rms(x)}")

# 验证与手动计算一致
rms = RMSNorm(64)
x = torch.randn(2, 10, 64)
output = rms(x)
expected_rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-5)
expected = x / expected_rms * rms.weight
assert torch.allclose(output, expected, atol=1e-5)
print("RMSNorm 验证通过")

面试高频: RMSNorm 去掉 mean 计算有什么好处?

  1. 计算量减少约 30%(省去一次求和)
  2. 对绝对位置信息更敏感,某些任务上有轻微精度提升
  3. 在 LLaMA 系列中被验证与 LayerNorm 效果相当甚至更优

Softmax

基础 Softmax

Softmax 将 logits 转换为概率分布,是注意力机制和分类输出的标配。

softmax(xi)=exijexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
def softmax(x, dim=-1):
    """基础 Softmax"""
    exp_x = torch.exp(x)
    return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

# 数值稳定性测试
x = torch.tensor([1000.0, 1001.0, 1002.0])
try:
    result = softmax(x)
    print(f"结果: {result}")
except Exception as e:
    print(f"溢出: {e}")

Safe Softmax

直接计算 exie^{x_i} 会导致数值溢出(x > 709 时 double 也会溢出)。解决方案是减去最大值:

softmax(xi)=eximax(x)jexjmax(x)\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}

数学上等价,因为分子分母同时除以 emax(x)e^{\max(x)} 不改变结果。

def safe_softmax(x, dim=-1):
    """数值稳定的 Softmax"""
    x_max = torch.max(x, dim=dim, keepdim=True).values
    exp_x = torch.exp(x - x_max)
    return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

# 验证
x = torch.tensor([1000.0, 1001.0, 1002.0])
result = safe_softmax(x)
print(f"Safe Softmax 结果: {result}")
print(f"和为 1: {result.sum():.6f}")

# 与官方对比
x = torch.randn(3, 4)
assert torch.allclose(safe_softmax(x), torch.softmax(x, dim=-1), atol=1e-5)
print("Safe Softmax 验证通过")

面试高频: Softmax 的数值不稳定问题如何解决?时间复杂度是多少?

减去最大值保证指数运算不溢出,结果数学等价。时间复杂度 O(n),空间复杂度 O(n)(需要存储中间结果)。实际工程中还经常用 online softmax 进一步减少内存访问。

注意力机制

Self-Attention

Self-Attention 是 Transformer 的核心,让序列中的每个位置都能关注其他所有位置。

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
class SelfAttention(nn.Module):
    """基础 Self-Attention"""
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.scale = torch.sqrt(torch.tensor(d_model, dtype=torch.float32))

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # 应用 mask(如有)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

# 验证
attn = SelfAttention(64)
x = torch.randn(2, 10, 64)
output, weights = attn(x)
print(f"输出形状: {output.shape}, 注意力权重形状: {weights.shape}")

面试高频: 为什么要除以 dk\sqrt{d_k}

dkd_k 较大时,QKTQK^T 的方差也会增大,导致 softmax 进入梯度极小的饱和区。缩放因子 dk\sqrt{d_k} 将方差控制在合理范围,保证梯度流动。这是 Attention Is All You Need 论文中的关键设计。

Cross-Attention

Cross-Attention 的 Query 来自一个序列,Key/Value 来自另一个序列,用于 Encoder-Decoder 架构和多模态融合。

class CrossAttention(nn.Module):
    """Cross Attention: Q from x, KV from context"""
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.scale = torch.sqrt(torch.tensor(d_model))

    def forward(self, x, context):
        # x: 目标序列 (batch, tgt_len, d_model)
        # context: 源序列 (batch, src_len, d_model)
        Q = self.W_q(x)
        K = self.W_k(context)
        V = self.W_v(context)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output

# 验证:用于机器翻译场景
cross_attn = CrossAttention(64)
target = torch.randn(2, 5, 64)   #  decoder 当前状态
source = torch.randn(2, 10, 64)  # encoder 输出
output = cross_attn(target, source)
print(f"Cross-Attention 输出形状: {output.shape}")

面试高频: Cross-Attention 的典型应用场景?

  1. 机器翻译:decoder 关注 encoder 输出的源语言表示
  2. 多模态:视觉特征作为 KV,文本 Query 关注图像区域
  3. T5 等 Encoder-Decoder 模型:连接编码器和解码器

Multi-Head Attention

Multi-Head Attention 将输入投影到多个子空间分别计算注意力,增强模型表达能力。

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention"""
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        # (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape

        # 线性投影并分头
        Q = self.split_heads(self.W_q(x))
        K = self.split_heads(self.W_k(x))
        V = self.split_heads(self.W_v(x))

        # 缩放点积注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # 合并多头并投影
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        return self.W_o(attn_output)

# 验证
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
output = mha(x)
print(f"MHA 输出形状: {output.shape}")

面试高频: 多头的直觉理解是什么?参数量如何计算?

每个 head 关注不同的特征子空间(语法 vs 语义、局部 vs 全局)。参数量:4 个投影矩阵 Wq,Wk,Wv,WoW_q, W_k, W_v, W_o,每个 dmodel×dmodeld_{model} \times d_{model},总参数量 4dmodel24d_{model}^2

KV Cache

KV Cache 是自回归生成的核心优化,缓存历史 KV 避免重复计算,将复杂度从 O(n3)O(n^3) 降到 O(n2)O(n^2)

class KVCacheAttention(nn.Module):
    """带 KV Cache 的 Attention(推理优化)"""
    def __init__(self, d_model, num_heads, max_batch_size, max_seq_len):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # 预分配 KV Cache
        self.register_buffer(
            'k_cache',
            torch.zeros(max_batch_size, num_heads, max_seq_len, self.d_k)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(max_batch_size, num_heads, max_seq_len, self.d_k)
        )
        self.cache_len = 0

    def forward(self, x, start_pos=0):
        """
        x: (batch, 1, d_model) 只输入当前 token
        start_pos: 当前 token 在序列中的位置
        """
        batch, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 更新 cache
        self.k_cache[:batch, :, start_pos:start_pos+seq_len] = K
        self.v_cache[:batch, :, start_pos:start_pos+seq_len] = V

        # 使用全部缓存计算注意力
        K_full = self.k_cache[:batch, :, :start_pos+seq_len]
        V_full = self.v_cache[:batch, :, :start_pos+seq_len]

        scores = torch.matmul(Q, K_full.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V_full)

        output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        return self.W_o(output)

# 演示逐步生成
cache_attn = KVCacheAttention(64, 8, max_batch_size=2, max_seq_len=100)

# 模拟生成过程
for i in range(5):
    token = torch.randn(2, 1, 64)  # 每次只输入一个新 token
    output = cache_attn(token, start_pos=i)
    print(f"步骤 {i}: 输出形状 {output.shape}")

面试高频: 为什么只缓存 KV 不缓存 Q?显存占用如何计算?

自回归生成时,当前 token 只作为 Query,需要与所有历史 KV 计算注意力。Query 不需要保留,用完即弃。显存占用:2×batch×num_heads×seq_len×dk2 \times batch \times num\_heads \times seq\_len \times d_k,对于 7B 模型、4K 上下文约需 2GB。

RoPE 旋转位置编码

RoPE(Rotary Position Embedding)通过旋转矩阵注入位置信息,是 LLaMA、Qwen 等现代 LLM 的标准配置。

核心思想:将相邻维度组成一对,应用二维旋转矩阵,旋转角度与位置索引成正比。

[x2ix2i+1][cosmθisinmθisinmθicosmθi]\begin{bmatrix} x_{2i} \\ x_{2i+1} \end{bmatrix} \cdot \begin{bmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{bmatrix}

其中 θi=100002i/d\theta_i = 10000^{-2i/d}mm 是位置索引。

class RoPE(nn.Module):
    """旋转位置编码 (Rotary Position Embedding)"""
    def __init__(self, d_model, max_seq_len=2048, base=10000):
        super().__init__()
        self.d_model = d_model

        # 计算旋转角度 theta_i = base^(-2i/d)
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

        # 预计算位置对应的 cos/sin
        positions = torch.arange(max_seq_len)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, d_model/2)

        self.register_buffer('cos', torch.cos(angles))
        self.register_buffer('sin', torch.sin(angles))

    def forward(self, x, seq_len):
        """
        x: (batch, num_heads, seq_len, d_k)
        应用旋转位置编码
        """
        # 分离奇偶维度
        x1 = x[..., ::2]   # 偶数维
        x2 = x[..., 1::2]  # 奇数维

        # 应用旋转
        cos = self.cos[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin[:seq_len].unsqueeze(0).unsqueeze(0)

        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos

        # 交错合并
        output = torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
        return output

# 验证
rope = RoPE(d_model=64)
x = torch.randn(2, 8, 10, 64)  # (batch, heads, seq, d_k)
output = rope(x, seq_len=10)
print(f"RoPE 输出形状: {output.shape}")

# 验证相对位置特性:相同相对位置的旋转角度相同
x1 = torch.randn(1, 1, 1, 64)
x2 = torch.randn(1, 1, 1, 64)

# 位置 5 的 token 和位置 3 的 token
rope_5_3 = rope(torch.cat([x1, x2], dim=2), seq_len=2)
# 位置 7 的 token 和位置 5 的 token
rope_7_5 = rope(torch.cat([x1, x2], dim=2), seq_len=2)

# 相对位置都是 2,点积应该相同
dot_5_3 = torch.sum(rope_5_3[0, 0, 0] * rope_5_3[0, 0, 1])
dot_7_5 = torch.sum(rope_7_5[0, 0, 0] * rope_7_5[0, 0, 1])
print(f"相对位置编码一致性: {torch.allclose(dot_5_3, dot_7_5, atol=1e-5)}")

面试高频: RoPE 相比绝对位置编码的优势?外推性问题如何解决?

RoPE 天然支持相对位置:两个 token 的注意力分数只依赖它们的相对距离,与绝对位置无关。这带来更好的长度外推性。训练时未见过的长序列上,RoPE 表现优于正弦/可学习位置编码。外推性问题通过调整 base(如从 10000 改到 1000000)或 NTK-aware 插值来缓解。

SwiGLU

SwiGLU 是现代 LLM(LLaMA、PaLM)的 FFN 层标准结构,用门控机制替代传统的 ReLU/GeLU。

SwiGLU(x)=(xW1SiLU(xW2))W3\text{SwiGLU}(x) = (xW_1 \odot \text{SiLU}(xW_2))W_3

结构包含三个线性层,中间隐藏维度通常是 8/3×dmodel8/3 \times d_{model}

class SwiGLU(nn.Module):
    """SwiGLU FFN"""
    def __init__(self, d_model, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            # LLaMA 使用 8/3 * d_model,取最近的 256 倍数
            hidden_dim = int(8 * d_model / 3)
            hidden_dim = (hidden_dim + 255) // 256 * 256

        self.W1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.W2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.W3 = nn.Linear(hidden_dim, d_model, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        # 门控分支
        gate = self.act(self.W1(x))
        # 值分支
        value = self.W2(x)
        # 逐元素乘积后投影
        return self.W3(gate * value)

# 验证
swiglu = SwiGLU(d_model=512)
x = torch.randn(2, 10, 512)
output = swiglu(x)
print(f"SwiGLU 输出形状: {output.shape}")
print(f"参数量: {sum(p.numel() for p in swiglu.parameters()):,}")

面试高频: SwiGLU 为什么比标准 FFN 效果好?参数量如何?

SwiGLU 引入门控机制,相当于让网络自适应选择激活路径,表达能力更强。GLU Variants Improve Transformer 论文中系统对比了各种门控结构。参数量是标准 FFN 的约 1.5 倍(3 个矩阵 vs 2 个),但 PaLM 实验表明在相同参数量预算下,SwiGLU 版本的模型效果更好。

KL 散度

KL 散度(Kullback-Leibler Divergence)衡量两个概率分布的差异,是 RLHF、知识蒸馏和变分推断的核心工具。

定义:对于分布 PPQQ

DKL(PQ)=iP(i)logP(i)Q(i)D_{KL}(P \| Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}

KL 散度非负,且当且仅当 P=QP = Q 时为 0。注意 DKL(PQ)DKL(QP)D_{KL}(P \| Q) \neq D_{KL}(Q \| P),即不对称。

数值示例:假设 PP 是学生模型输出,QQ 是教师模型输出:

tokenPP (学生)QQ (教师)PlogPQP\log\frac{P}{Q}贡献
”猫”0.50.70.5×log(0.5/0.7)0.5 \times \log(0.5/0.7)-0.168
”狗”0.30.20.3×log(0.3/0.2)0.3 \times \log(0.3/0.2)+0.122
”鸟”0.20.10.2×log(0.2/0.1)0.2 \times \log(0.2/0.1)+0.139
合计1.01.0DKL(PQ)D_{KL}(P\|Q)0.093

学生给 “猫” 的概率低于教师,这部分贡献了负值;给 “狗” 和 “鸟” 的概率高于教师,贡献了正值。KL 散度是所有 token 贡献的总和。

def kl_divergence(p, q, eps=1e-10):
    """
    计算 KL(P || Q)
    p, q: 概率分布,shape 相同
    """
    # 确保是概率分布
    p = p.clamp(min=eps)
    q = q.clamp(min=eps)

    # KL = sum(p * log(p / q))
    return torch.sum(p * torch.log(p / q))

# 数值示例验证
p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.7, 0.2, 0.1])

# 逐 token 验证
for i, token in enumerate(["", "", ""]):
    contrib = p[i] * torch.log(p[i] / q[i])
    print(f"{token}: P={p[i]:.1f}, Q={q[i]:.1f}, 贡献={contrib:.4f}")

kl_result = kl_divergence(p, q)
print(f"\n总 KL 散度: {kl_result:.4f}")

# 验证:相同分布 KL 为 0
p_same = torch.tensor([0.2, 0.3, 0.5])
q_same = torch.tensor([0.2, 0.3, 0.5])
print(f"相同分布 KL: {kl_divergence(p_same, q_same):.6f}")

# 验证不对称性
print(f"KL(P||Q) vs KL(Q||P): {kl_divergence(p, q):.4f} vs {kl_divergence(q, p):.4f}")

在实际工程中,常常需要计算 logits 之间的 KL 散度(如蒸馏场景),这时可以直接用 log_softmax 避免数值问题:

def kl_divergence_logits(student_logits, teacher_logits, temperature=1.0):
    """
    从 logits 计算 KL 散度(知识蒸馏场景)
    """
    student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = torch.softmax(teacher_logits / temperature, dim=-1)
    teacher_log_probs = torch.log_softmax(teacher_logits / temperature, dim=-1)

    # KL = sum(p * (log p - log q))
    kl = teacher_probs * (teacher_log_probs - student_log_probs)
    return torch.sum(kl, dim=-1).mean()

# 验证与 PyTorch 的一致性
student = torch.randn(2, 10)
teacher = torch.randn(2, 10)

my_kl = kl_divergence_logits(student, teacher)
torch_kl = nn.functional.kl_div(
    torch.log_softmax(student, dim=-1),
    torch.softmax(teacher, dim=-1),
    reduction='batchmean'
)
print(f"KL Div 误差: {(my_kl - torch_kl).abs():.2e}")

面试高频: Forward KL vs Reverse KL 的区别?RLHF 中为什么用 Reverse KL?

Forward KL DKL(PQ)D_{KL}(P \| Q) 倾向于让 QQ 覆盖 PP 的所有支持区域(mode-covering),即使某些区域概率很小。Reverse KL DKL(QP)D_{KL}(Q \| P) 则让 QQ 集中在 PP 的高概率区域(mode-seeking)。RLHF 中使用 Reverse KL 约束模型不要偏离 reference model 太远,避免模式坍塌(mode collapse)到单一 high-reward 序列。

极大似然估计(MLE)

MLE 是语言模型训练的理论基础。核心思想:选择使观测数据似然最大的模型参数。

对于语言模型,目标是最大化序列的联合概率:

P(x1,x2,...,xT)=t=1TP(xtx<t)P(x_1, x_2, ..., x_T) = \prod_{t=1}^T P(x_t | x_{<t})

取对数后转化为最小化负对数似然:

L=t=1TlogP(xtx<t)\mathcal{L} = -\sum_{t=1}^T \log P(x_t | x_{<t})

数值示例:假设 vocab_size=5,模型输出 logits,真实 token 是 “苹果”(索引 2):

tokenLogit概率 PPlogP-\log P
“香蕉”1.00.0902.41
”橙子”2.00.2451.40
”苹果”3.00.6650.41
”葡萄”0.50.0552.90
”西瓜”-1.00.0154.20

真实 token “苹果” 的概率是 0.665,该位置的损失是 log(0.665)0.41-\log(0.665) \approx 0.41。模型越自信(概率接近 1),损失越小;越不确定(概率接近 0),损失越大。

def cross_entropy_loss(logits, targets):
    """
    手动实现交叉熵损失(MLE 的具体形式)
    logits: (batch, seq_len, vocab_size)
    targets: (batch, seq_len)
    """
    batch_size, seq_len, vocab_size = logits.shape

    # 展平
    logits_flat = logits.view(-1, vocab_size)
    targets_flat = targets.view(-1)

    # Softmax
    probs = torch.softmax(logits_flat, dim=-1)

    # 取目标位置的负对数概率
    log_probs = torch.log(probs + 1e-10)
    nll_loss = -log_probs[torch.arange(len(targets_flat)), targets_flat]

    return nll_loss.mean()

# 数值示例演示
logits_demo = torch.tensor([1.0, 2.0, 3.0, 0.5, -1.0])  # 5个token的logits
target_demo = torch.tensor(2)  # "苹果"的索引

probs = torch.softmax(logits_demo, dim=-1)
loss = -torch.log(probs[target_demo])

print("Token 概率分布:")
for i, (name, prob) in enumerate(zip(["香蕉", "橙子", "苹果", "葡萄", "西瓜"], probs)):
    marker = " <- 目标" if i == target_demo else ""
    print(f"  {name}: logit={logits_demo[i]:.1f}, prob={prob:.3f}, -logP={-torch.log(prob):.2f}{marker}")

print(f"\n该位置损失: {loss:.4f}")

# 验证与 PyTorch 一致
logits = torch.randn(2, 10, 50000)
targets = torch.randint(0, 50000, (2, 10))
my_loss = cross_entropy_loss(logits, targets)
torch_loss = nn.functional.cross_entropy(logits.view(-1, 50000), targets.view(-1))
assert torch.allclose(my_loss, torch_loss, atol=1e-4)
print(f"\nMLE Loss 验证通过: {my_loss:.4f}")

面试高频: 为什么语言模型用交叉熵损失?与 MLE 的关系?

交叉熵损失等价于 MLE 在分类问题中的具体实现。最小化交叉熵 = 最大化训练数据的似然。对于下一个 token 预测,交叉熵直接衡量模型输出的概率分布与真实分布的差异,梯度更新让正确 token 的概率增大。

交叉熵

交叉熵衡量两个概率分布的差异。在分类任务中,一个是模型的预测分布 PP,另一个是真实标签的 one-hot 分布 QQ

H(Q,P)=iQ(i)logP(i)H(Q, P) = -\sum_i Q(i) \log P(i)

对于 one-hot 标签(只有正确类别为 1,其余为 0),公式简化为 logP(ytrue)-\log P(y_{\text{true}}),这就是 MLE 中的负对数似然。

数值示例:三分类问题,模型输出 vs 真实标签:

类别模型预测 PP真实标签 QQ贡献 QlogP-Q \log P
“猫”0.700
”狗”0.200
”鸟”0.11log(0.1)=2.30-\log(0.1) = 2.30
交叉熵2.30

模型给正确类别 “鸟” 的概率只有 0.1,损失高达 2.30。如果模型更自信(P=0.9P=0.9),损失降至 0.10。

def cross_entropy_manual(pred_probs, true_labels):
    """
    手动实现交叉熵
    pred_probs: (batch, num_classes) 模型输出的概率
    true_labels: (batch,) 真实类别索引
    """
    batch_size = pred_probs.shape[0]
    # 取正确类别的概率
    correct_probs = pred_probs[torch.arange(batch_size), true_labels]
    # 负对数似然
    loss = -torch.log(correct_probs + 1e-10)
    return loss.mean()

# 数值示例验证
probs_demo = torch.tensor([0.7, 0.2, 0.1])  # 模型预测
true_label = torch.tensor(2)  # 真实类别是"鸟"

print("交叉熵计算过程:")
for i, name in enumerate(["", "", ""]):
    is_correct = "" if i == true_label else " "
    contrib = 0.0 if i != true_label else -torch.log(probs_demo[i])
    print(f"  {is_correct} {name}: P={probs_demo[i]:.1f}, 贡献={contrib:.4f}")

ce_loss = -torch.log(probs_demo[true_label])
print(f"\n总交叉熵损失: {ce_loss:.4f}")

# 对比:模型更自信的情况
confident_probs = torch.tensor([0.05, 0.05, 0.9])
confident_loss = -torch.log(confident_probs[true_label])
print(f"如果 P(鸟)=0.9,损失降至: {confident_loss:.4f}")

# 验证与 PyTorch 一致
pred = torch.tensor([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
labels = torch.tensor([2, 1])

my_ce = cross_entropy_manual(pred, labels)
torch_ce = nn.functional.cross_entropy(torch.log(pred), labels)
print(f"\n交叉熵验证通过: {my_ce:.4f}")

面试高频: 交叉熵与 KL 散度的关系?

交叉熵 H(Q,P)=H(Q)+DKL(QP)H(Q, P) = H(Q) + D_{KL}(Q \| P),其中 H(Q)H(Q) 是真实分布的熵(训练时固定)。因此最小化交叉熵等价于最小化 DKL(QP)D_{KL}(Q \| P),即让模型分布逼近真实分布。

反向传播与梯度下降

反向传播是神经网络训练的核心算法,基于链式法则高效计算梯度。

class SimpleNN(nn.Module):
    """简单两层网络,演示反向传播"""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        self.z1 = self.fc1(x)        # 第一层线性
        self.a1 = self.relu(self.z1) # ReLU激活
        self.z2 = self.fc2(self.a1)  # 第二层线性
        return self.z2

    def backward(self, x, y, lr=0.01):
        """手动反向传播"""
        batch_size = x.shape[0]

        # 前向
        output = self.forward(x)
        loss = nn.MSELoss()(output, y)

        # 反向:输出层梯度
        dz2 = (output - y) / batch_size  # dL/dz2
        dw2 = torch.matmul(self.a1.t(), dz2)  # dL/dW2 = a1^T @ dz2
        db2 = dz2.sum(dim=0)  # dL/db2

        # 反向:隐藏层梯度
        da1 = torch.matmul(dz2, self.fc2.weight)  # dL/da1 = dz2 @ W2^T
        dz1 = da1 * (self.z1 > 0).float()  # dL/dz1 = da1 * ReLU'(z1)
        dw1 = torch.matmul(x.t(), dz1)  # dL/dW1 = x^T @ dz1
        db1 = dz1.sum(dim=0)  # dL/db1

        # 参数更新(SGD)
        with torch.no_grad():
            self.fc1.weight -= lr * dw1.t()
            self.fc1.bias -= lr * db1
            self.fc2.weight -= lr * dw2.t()
            self.fc2.bias -= lr * db2

        return loss.item()

# 对比手动和自动求导
model = SimpleNN(10, 20, 1)
x = torch.randn(32, 10)
y = torch.randn(32, 1)

# PyTorch 自动求导
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
output = model(x)
loss1 = nn.MSELoss()(output, y)
optimizer.zero_grad()
loss1.backward()

# 保存梯度
grad_auto = model.fc1.weight.grad.clone()

# 重新初始化,手动求导
model = SimpleNN(10, 20, 1)
loss2 = model.backward(x, y, lr=0.01)

print(f"自动求导完成")
print(f"手动求导 Loss: {loss2:.6f}")

面试高频: 链式法则的矩阵形式如何理解?梯度消失/爆炸的本质?

反向传播时,梯度通过雅可比矩阵连乘传递。LW1=Lz2z2a1a1z1z1W1\frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial z_1} \cdot \frac{\partial z_1}{\partial W_1}。梯度消失源于激活函数导数 < 1(如 Sigmoid 最大 0.25),多层连乘后指数级衰减。梯度爆炸则来自权重矩阵特征值 > 1,可用梯度裁剪、残差连接、LayerNorm 缓解。

复杂度速查表

算子时间复杂度空间复杂度面试频率
Sigmoid/GELU/SiLUO(n)O(1)⭐⭐
LayerNorm/RMSNormO(batch × seq × hidden)O(hidden)⭐⭐⭐⭐
SoftmaxO(n)O(n)⭐⭐⭐
Self-AttentionO(batch × n² × d)O(n²)⭐⭐⭐⭐⭐
Multi-Head AttentionO(batch × h × n² × d_h)O(h × n²)⭐⭐⭐⭐⭐
KV CacheO(1) 每步O(batch × h × max_len × d_h)⭐⭐⭐⭐⭐
RoPEO(n × d)O(max_len × d/2)⭐⭐⭐⭐
SwiGLUO(batch × seq × d × hidden)O(hidden)⭐⭐⭐
KL DivergenceO(vocab)O(vocab)⭐⭐⭐⭐

说明: n 为序列长度,d 为模型维度,h 为 head 数,d_h = d/h。

掌握这些算子的实现细节,面试中遇到”手撕 Attention”或”讲清楚 LayerNorm”这类问题就能从容应对。代码建议收藏,需要时直接查阅。



Next Post
Activation Checkpointing:用时间换显存的艺术