大模型的核心能力来自这些基础算子的组合。面试中手撕 Attention、写 LayerNorm 是常规操作,但很多人只背了公式,没真正动手实现过。本文从激活函数到注意力机制,逐个拆解 PyTorch 版的精简实现,标注面试高频考点。
激活函数
Sigmoid
Sigmoid 将任意实数映射到 (0, 1),适合表示概率。现代 LLM 已很少直接使用,但在门控机制(如 LSTM、GRU)中仍有应用。
导数有一个优美的性质:,这使得反向传播计算非常高效。
import torch
import torch.nn as nn
class Sigmoid(nn.Module):
"""Sigmoid 激活函数"""
def forward(self, x):
return 1 / (1 + torch.exp(-x))
# 验证
x = torch.randn(3, 4)
my_sigmoid = Sigmoid()
torch_sigmoid = nn.Sigmoid()
assert torch.allclose(my_sigmoid(x), torch_sigmoid(x), atol=1e-6)
print(f"最大误差: {(my_sigmoid(x) - torch_sigmoid(x)).abs().max():.2e}")
面试高频: Sigmoid 的梯度消失问题是如何产生的?
当输入绝对值较大时(x > 5 或 x < -5),Sigmoid 输出接近 0 或 1,导数接近 0。多层堆叠后,梯度连乘导致梯度消失。这也是为什么深度网络更偏好 ReLU 及其变体。
GELU
GELU(Gaussian Error Linear Unit)是 BERT、GPT 系列的标准激活。相比 ReLU 的硬截断,GELU 提供平滑的非线性变换。
精确计算依赖误差函数 erf:
工程上更常用 tanh 近似(Hendrycks 2016):
class GELU(nn.Module):
"""GELU 激活函数(tanh 近似版)"""
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))
# 验证
x = torch.randn(3, 4)
my_gelu = GELU()
torch_gelu = nn.GELU(approximate='tanh')
assert torch.allclose(my_gelu(x), torch_gelu(x), atol=1e-5)
print(f"最大误差: {(my_gelu(x) - torch_gelu(x)).abs().max():.2e}")
面试高频: GELU 与 ReLU 的本质区别是什么?
ReLU 是确定性的硬阈值(x < 0 则输出 0),GELU 则是概率性的软过渡。从随机正则化角度理解:GELU 等价于对输入施加 Dropout 的期望输出。这种平滑性让优化更稳定,尤其在深层网络中。
SiLU / Swish
SiLU(Sigmoid Linear Unit),也叫 Swish,是 LLaMA、Mistral 等现代 LLM 的首选激活。Google Brain 2017 年提出,发现其性能在深层网络上优于 ReLU。
结构上是输入与 sigmoid 门的逐元素乘积,实现了自门控(self-gated)机制。
class SiLU(nn.Module):
"""SiLU / Swish 激活函数"""
def forward(self, x):
return x * torch.sigmoid(x)
# 验证
x = torch.randn(3, 4)
my_silu = SiLU()
torch_silu = nn.SiLU()
assert torch.allclose(my_silu(x), torch_silu(x), atol=1e-6)
print(f"最大误差: {(my_silu(x) - torch_silu(x)).abs().max():.2e}")
面试高频: LLaMA 为什么选 SiLU 而不是 GELU?
两者性能差距不大,但 SiLU 计算量略小(没有立方运算)。PaLM、LLaMA-2 的实验表明,在足够大的模型和数据上,SiLU 的收敛速度稍快。这属于工程实践中的细微取舍。
归一化层
BatchNorm vs LayerNorm vs RMSNorm
三种归一化的核心差异在于沿着哪个维度计算统计量:
| 归一化 | 计算维度 | 适用场景 | 大模型使用 |
|---|---|---|---|
| BatchNorm | (N, H, W) 的 batch 维 | CNN | 几乎不用 |
| LayerNorm | (H) 的特征维 | Transformer | 标配 |
| RMSNorm | (H) 的特征维(仅 std) | 现代 LLM | LLaMA/Qwen 首选 |
BatchNorm 依赖 batch 统计量,在序列长度变长的 Transformer 中不稳定。LayerNorm 和 RMSNorm 对每个样本独立计算,更适合变长序列。
LayerNorm
LayerNorm 对每个样本的所有特征做归一化,消除样本间的分布差异。
和 是可学习的缩放和平移参数。
数值示例:假设有一个 4 维输入向量,:
| 步骤 | 计算 | 结果 |
|---|---|---|
| 原始输入 | [1.0, 2.0, 3.0, 4.0] | |
| 均值 | 2.5 | |
| 方差 | 1.25 | |
| 归一化 | [-1.34, -0.45, 0.45, 1.34] | |
| 输出 | [-1.34, -0.45, 0.45, 1.34] |
归一化后均值为 0,方差为 1。实际训练中 和 会被学习成适合当前层的分布。
class LayerNorm(nn.Module):
"""Layer Normalization"""
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(normalized_shape))
self.beta = nn.Parameter(torch.zeros(normalized_shape))
def forward(self, x):
# x: (..., normalized_shape)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + self.eps)
return self.gamma * x_norm + self.beta
# 验证数值示例
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
ln = LayerNorm(4)
with torch.no_grad():
# 手动计算验证
mean = x.mean() # 2.5
var = x.var(unbiased=False) # 1.25
expected = (x - mean) / torch.sqrt(var + 1e-5)
print(f"均值: {mean:.2f}, 方差: {var:.2f}")
print(f"预期归一化: {expected}")
print(f"实际输出: {ln(x)}")
# 与官方实现对比
ln = LayerNorm(64)
torch_ln = nn.LayerNorm(64)
x = torch.randn(2, 10, 64)
with torch.no_grad():
torch_ln.weight.copy_(ln.gamma)
torch_ln.bias.copy_(ln.beta)
assert torch.allclose(ln(x), torch_ln(x), atol=1e-5)
print("LayerNorm 验证通过")
面试高频: Transformer 为什么用 LayerNorm 而不是 BatchNorm?
- 序列长度可变,BatchNorm 的 batch 统计量不稳定
- 推理时 BatchNorm 需要维护 running statistics,LayerNorm 训练和测试行为一致
- LayerNorm 对样本独立计算,更适合自回归生成场景
RMSNorm
RMSNorm(Root Mean Square LayerNorm)去掉了 LayerNorm 的 mean 计算,只保留 RMS 缩放。LLaMA-2 技术报告提到这能略微提升训练稳定性。
数值对比:使用与 LayerNorm 示例相同的输入 [1.0, 2.0, 3.0, 4.0]:
| 归一化 | 计算 | 输出结果 |
|---|---|---|
| LayerNorm | [-1.34, -0.45, 0.45, 1.34] | |
| RMSNorm | [0.37, 0.73, 1.10, 1.46] |
注意 RMSNorm 不中心化为 0,而是保留原始数值的相对大小。对于 [1,2,3,4] 这种全正输入,RMSNorm 输出也都是正数。
class RMSNorm(nn.Module):
"""RMS Normalization"""
def __init__(self, normalized_shape, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(normalized_shape))
def forward(self, x):
# 只计算 RMS,不做中心化
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
x_norm = x / rms
return self.weight * x_norm
# 验证数值示例
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
rms = RMSNorm(4)
with torch.no_grad():
rms_val = torch.sqrt(torch.mean(x ** 2)) # sqrt(30/4) = 2.74
expected = x / rms_val
print(f"RMS: {rms_val:.2f}")
print(f"预期输出: {expected}")
print(f"实际输出: {rms(x)}")
# 验证与手动计算一致
rms = RMSNorm(64)
x = torch.randn(2, 10, 64)
output = rms(x)
expected_rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-5)
expected = x / expected_rms * rms.weight
assert torch.allclose(output, expected, atol=1e-5)
print("RMSNorm 验证通过")
面试高频: RMSNorm 去掉 mean 计算有什么好处?
- 计算量减少约 30%(省去一次求和)
- 对绝对位置信息更敏感,某些任务上有轻微精度提升
- 在 LLaMA 系列中被验证与 LayerNorm 效果相当甚至更优
Softmax
基础 Softmax
Softmax 将 logits 转换为概率分布,是注意力机制和分类输出的标配。
def softmax(x, dim=-1):
"""基础 Softmax"""
exp_x = torch.exp(x)
return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)
# 数值稳定性测试
x = torch.tensor([1000.0, 1001.0, 1002.0])
try:
result = softmax(x)
print(f"结果: {result}")
except Exception as e:
print(f"溢出: {e}")
Safe Softmax
直接计算 会导致数值溢出(x > 709 时 double 也会溢出)。解决方案是减去最大值:
数学上等价,因为分子分母同时除以 不改变结果。
def safe_softmax(x, dim=-1):
"""数值稳定的 Softmax"""
x_max = torch.max(x, dim=dim, keepdim=True).values
exp_x = torch.exp(x - x_max)
return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)
# 验证
x = torch.tensor([1000.0, 1001.0, 1002.0])
result = safe_softmax(x)
print(f"Safe Softmax 结果: {result}")
print(f"和为 1: {result.sum():.6f}")
# 与官方对比
x = torch.randn(3, 4)
assert torch.allclose(safe_softmax(x), torch.softmax(x, dim=-1), atol=1e-5)
print("Safe Softmax 验证通过")
面试高频: Softmax 的数值不稳定问题如何解决?时间复杂度是多少?
减去最大值保证指数运算不溢出,结果数学等价。时间复杂度 O(n),空间复杂度 O(n)(需要存储中间结果)。实际工程中还经常用 online softmax 进一步减少内存访问。
注意力机制
Self-Attention
Self-Attention 是 Transformer 的核心,让序列中的每个位置都能关注其他所有位置。
class SelfAttention(nn.Module):
"""基础 Self-Attention"""
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.scale = torch.sqrt(torch.tensor(d_model, dtype=torch.float32))
def forward(self, x, mask=None):
# x: (batch, seq_len, d_model)
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# 应用 mask(如有)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output, attn_weights
# 验证
attn = SelfAttention(64)
x = torch.randn(2, 10, 64)
output, weights = attn(x)
print(f"输出形状: {output.shape}, 注意力权重形状: {weights.shape}")
面试高频: 为什么要除以 ?
当 较大时, 的方差也会增大,导致 softmax 进入梯度极小的饱和区。缩放因子 将方差控制在合理范围,保证梯度流动。这是 Attention Is All You Need 论文中的关键设计。
Cross-Attention
Cross-Attention 的 Query 来自一个序列,Key/Value 来自另一个序列,用于 Encoder-Decoder 架构和多模态融合。
class CrossAttention(nn.Module):
"""Cross Attention: Q from x, KV from context"""
def __init__(self, d_model):
super().__init__()
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.scale = torch.sqrt(torch.tensor(d_model))
def forward(self, x, context):
# x: 目标序列 (batch, tgt_len, d_model)
# context: 源序列 (batch, src_len, d_model)
Q = self.W_q(x)
K = self.W_k(context)
V = self.W_v(context)
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output
# 验证:用于机器翻译场景
cross_attn = CrossAttention(64)
target = torch.randn(2, 5, 64) # decoder 当前状态
source = torch.randn(2, 10, 64) # encoder 输出
output = cross_attn(target, source)
print(f"Cross-Attention 输出形状: {output.shape}")
面试高频: Cross-Attention 的典型应用场景?
- 机器翻译:decoder 关注 encoder 输出的源语言表示
- 多模态:视觉特征作为 KV,文本 Query 关注图像区域
- T5 等 Encoder-Decoder 模型:连接编码器和解码器
Multi-Head Attention
Multi-Head Attention 将输入投影到多个子空间分别计算注意力,增强模型表达能力。
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention"""
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
# (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
batch, seq_len, _ = x.shape
return x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, x, mask=None):
batch, seq_len, _ = x.shape
# 线性投影并分头
Q = self.split_heads(self.W_q(x))
K = self.split_heads(self.W_k(x))
V = self.split_heads(self.W_v(x))
# 缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 合并多头并投影
attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
return self.W_o(attn_output)
# 验证
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)
output = mha(x)
print(f"MHA 输出形状: {output.shape}")
面试高频: 多头的直觉理解是什么?参数量如何计算?
每个 head 关注不同的特征子空间(语法 vs 语义、局部 vs 全局)。参数量:4 个投影矩阵 ,每个 ,总参数量 。
KV Cache
KV Cache 是自回归生成的核心优化,缓存历史 KV 避免重复计算,将复杂度从 降到 。
class KVCacheAttention(nn.Module):
"""带 KV Cache 的 Attention(推理优化)"""
def __init__(self, d_model, num_heads, max_batch_size, max_seq_len):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
# 预分配 KV Cache
self.register_buffer(
'k_cache',
torch.zeros(max_batch_size, num_heads, max_seq_len, self.d_k)
)
self.register_buffer(
'v_cache',
torch.zeros(max_batch_size, num_heads, max_seq_len, self.d_k)
)
self.cache_len = 0
def forward(self, x, start_pos=0):
"""
x: (batch, 1, d_model) 只输入当前 token
start_pos: 当前 token 在序列中的位置
"""
batch, seq_len, _ = x.shape
Q = self.W_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 更新 cache
self.k_cache[:batch, :, start_pos:start_pos+seq_len] = K
self.v_cache[:batch, :, start_pos:start_pos+seq_len] = V
# 使用全部缓存计算注意力
K_full = self.k_cache[:batch, :, :start_pos+seq_len]
V_full = self.v_cache[:batch, :, :start_pos+seq_len]
scores = torch.matmul(Q, K_full.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V_full)
output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
return self.W_o(output)
# 演示逐步生成
cache_attn = KVCacheAttention(64, 8, max_batch_size=2, max_seq_len=100)
# 模拟生成过程
for i in range(5):
token = torch.randn(2, 1, 64) # 每次只输入一个新 token
output = cache_attn(token, start_pos=i)
print(f"步骤 {i}: 输出形状 {output.shape}")
面试高频: 为什么只缓存 KV 不缓存 Q?显存占用如何计算?
自回归生成时,当前 token 只作为 Query,需要与所有历史 KV 计算注意力。Query 不需要保留,用完即弃。显存占用:,对于 7B 模型、4K 上下文约需 2GB。
RoPE 旋转位置编码
RoPE(Rotary Position Embedding)通过旋转矩阵注入位置信息,是 LLaMA、Qwen 等现代 LLM 的标准配置。
核心思想:将相邻维度组成一对,应用二维旋转矩阵,旋转角度与位置索引成正比。
其中 , 是位置索引。
class RoPE(nn.Module):
"""旋转位置编码 (Rotary Position Embedding)"""
def __init__(self, d_model, max_seq_len=2048, base=10000):
super().__init__()
self.d_model = d_model
# 计算旋转角度 theta_i = base^(-2i/d)
inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
self.register_buffer('inv_freq', inv_freq)
# 预计算位置对应的 cos/sin
positions = torch.arange(max_seq_len)
angles = torch.outer(positions, inv_freq) # (max_seq_len, d_model/2)
self.register_buffer('cos', torch.cos(angles))
self.register_buffer('sin', torch.sin(angles))
def forward(self, x, seq_len):
"""
x: (batch, num_heads, seq_len, d_k)
应用旋转位置编码
"""
# 分离奇偶维度
x1 = x[..., ::2] # 偶数维
x2 = x[..., 1::2] # 奇数维
# 应用旋转
cos = self.cos[:seq_len].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len].unsqueeze(0).unsqueeze(0)
rotated_x1 = x1 * cos - x2 * sin
rotated_x2 = x1 * sin + x2 * cos
# 交错合并
output = torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
return output
# 验证
rope = RoPE(d_model=64)
x = torch.randn(2, 8, 10, 64) # (batch, heads, seq, d_k)
output = rope(x, seq_len=10)
print(f"RoPE 输出形状: {output.shape}")
# 验证相对位置特性:相同相对位置的旋转角度相同
x1 = torch.randn(1, 1, 1, 64)
x2 = torch.randn(1, 1, 1, 64)
# 位置 5 的 token 和位置 3 的 token
rope_5_3 = rope(torch.cat([x1, x2], dim=2), seq_len=2)
# 位置 7 的 token 和位置 5 的 token
rope_7_5 = rope(torch.cat([x1, x2], dim=2), seq_len=2)
# 相对位置都是 2,点积应该相同
dot_5_3 = torch.sum(rope_5_3[0, 0, 0] * rope_5_3[0, 0, 1])
dot_7_5 = torch.sum(rope_7_5[0, 0, 0] * rope_7_5[0, 0, 1])
print(f"相对位置编码一致性: {torch.allclose(dot_5_3, dot_7_5, atol=1e-5)}")
面试高频: RoPE 相比绝对位置编码的优势?外推性问题如何解决?
RoPE 天然支持相对位置:两个 token 的注意力分数只依赖它们的相对距离,与绝对位置无关。这带来更好的长度外推性。训练时未见过的长序列上,RoPE 表现优于正弦/可学习位置编码。外推性问题通过调整 base(如从 10000 改到 1000000)或 NTK-aware 插值来缓解。
SwiGLU
SwiGLU 是现代 LLM(LLaMA、PaLM)的 FFN 层标准结构,用门控机制替代传统的 ReLU/GeLU。
结构包含三个线性层,中间隐藏维度通常是 。
class SwiGLU(nn.Module):
"""SwiGLU FFN"""
def __init__(self, d_model, hidden_dim=None):
super().__init__()
if hidden_dim is None:
# LLaMA 使用 8/3 * d_model,取最近的 256 倍数
hidden_dim = int(8 * d_model / 3)
hidden_dim = (hidden_dim + 255) // 256 * 256
self.W1 = nn.Linear(d_model, hidden_dim, bias=False)
self.W2 = nn.Linear(d_model, hidden_dim, bias=False)
self.W3 = nn.Linear(hidden_dim, d_model, bias=False)
self.act = nn.SiLU()
def forward(self, x):
# 门控分支
gate = self.act(self.W1(x))
# 值分支
value = self.W2(x)
# 逐元素乘积后投影
return self.W3(gate * value)
# 验证
swiglu = SwiGLU(d_model=512)
x = torch.randn(2, 10, 512)
output = swiglu(x)
print(f"SwiGLU 输出形状: {output.shape}")
print(f"参数量: {sum(p.numel() for p in swiglu.parameters()):,}")
面试高频: SwiGLU 为什么比标准 FFN 效果好?参数量如何?
SwiGLU 引入门控机制,相当于让网络自适应选择激活路径,表达能力更强。GLU Variants Improve Transformer 论文中系统对比了各种门控结构。参数量是标准 FFN 的约 1.5 倍(3 个矩阵 vs 2 个),但 PaLM 实验表明在相同参数量预算下,SwiGLU 版本的模型效果更好。
KL 散度
KL 散度(Kullback-Leibler Divergence)衡量两个概率分布的差异,是 RLHF、知识蒸馏和变分推断的核心工具。
定义:对于分布 和
KL 散度非负,且当且仅当 时为 0。注意 ,即不对称。
数值示例:假设 是学生模型输出, 是教师模型输出:
| token | (学生) | (教师) | 贡献 | |
|---|---|---|---|---|
| ”猫” | 0.5 | 0.7 | -0.168 | |
| ”狗” | 0.3 | 0.2 | +0.122 | |
| ”鸟” | 0.2 | 0.1 | +0.139 | |
| 合计 | 1.0 | 1.0 | 0.093 |
学生给 “猫” 的概率低于教师,这部分贡献了负值;给 “狗” 和 “鸟” 的概率高于教师,贡献了正值。KL 散度是所有 token 贡献的总和。
def kl_divergence(p, q, eps=1e-10):
"""
计算 KL(P || Q)
p, q: 概率分布,shape 相同
"""
# 确保是概率分布
p = p.clamp(min=eps)
q = q.clamp(min=eps)
# KL = sum(p * log(p / q))
return torch.sum(p * torch.log(p / q))
# 数值示例验证
p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.7, 0.2, 0.1])
# 逐 token 验证
for i, token in enumerate(["猫", "狗", "鸟"]):
contrib = p[i] * torch.log(p[i] / q[i])
print(f"{token}: P={p[i]:.1f}, Q={q[i]:.1f}, 贡献={contrib:.4f}")
kl_result = kl_divergence(p, q)
print(f"\n总 KL 散度: {kl_result:.4f}")
# 验证:相同分布 KL 为 0
p_same = torch.tensor([0.2, 0.3, 0.5])
q_same = torch.tensor([0.2, 0.3, 0.5])
print(f"相同分布 KL: {kl_divergence(p_same, q_same):.6f}")
# 验证不对称性
print(f"KL(P||Q) vs KL(Q||P): {kl_divergence(p, q):.4f} vs {kl_divergence(q, p):.4f}")
在实际工程中,常常需要计算 logits 之间的 KL 散度(如蒸馏场景),这时可以直接用 log_softmax 避免数值问题:
def kl_divergence_logits(student_logits, teacher_logits, temperature=1.0):
"""
从 logits 计算 KL 散度(知识蒸馏场景)
"""
student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1)
teacher_probs = torch.softmax(teacher_logits / temperature, dim=-1)
teacher_log_probs = torch.log_softmax(teacher_logits / temperature, dim=-1)
# KL = sum(p * (log p - log q))
kl = teacher_probs * (teacher_log_probs - student_log_probs)
return torch.sum(kl, dim=-1).mean()
# 验证与 PyTorch 的一致性
student = torch.randn(2, 10)
teacher = torch.randn(2, 10)
my_kl = kl_divergence_logits(student, teacher)
torch_kl = nn.functional.kl_div(
torch.log_softmax(student, dim=-1),
torch.softmax(teacher, dim=-1),
reduction='batchmean'
)
print(f"KL Div 误差: {(my_kl - torch_kl).abs():.2e}")
面试高频: Forward KL vs Reverse KL 的区别?RLHF 中为什么用 Reverse KL?
Forward KL 倾向于让 覆盖 的所有支持区域(mode-covering),即使某些区域概率很小。Reverse KL 则让 集中在 的高概率区域(mode-seeking)。RLHF 中使用 Reverse KL 约束模型不要偏离 reference model 太远,避免模式坍塌(mode collapse)到单一 high-reward 序列。
极大似然估计(MLE)
MLE 是语言模型训练的理论基础。核心思想:选择使观测数据似然最大的模型参数。
对于语言模型,目标是最大化序列的联合概率:
取对数后转化为最小化负对数似然:
数值示例:假设 vocab_size=5,模型输出 logits,真实 token 是 “苹果”(索引 2):
| token | Logit | 概率 | |
|---|---|---|---|
| “香蕉” | 1.0 | 0.090 | 2.41 |
| ”橙子” | 2.0 | 0.245 | 1.40 |
| ”苹果” | 3.0 | 0.665 | 0.41 |
| ”葡萄” | 0.5 | 0.055 | 2.90 |
| ”西瓜” | -1.0 | 0.015 | 4.20 |
真实 token “苹果” 的概率是 0.665,该位置的损失是 。模型越自信(概率接近 1),损失越小;越不确定(概率接近 0),损失越大。
def cross_entropy_loss(logits, targets):
"""
手动实现交叉熵损失(MLE 的具体形式)
logits: (batch, seq_len, vocab_size)
targets: (batch, seq_len)
"""
batch_size, seq_len, vocab_size = logits.shape
# 展平
logits_flat = logits.view(-1, vocab_size)
targets_flat = targets.view(-1)
# Softmax
probs = torch.softmax(logits_flat, dim=-1)
# 取目标位置的负对数概率
log_probs = torch.log(probs + 1e-10)
nll_loss = -log_probs[torch.arange(len(targets_flat)), targets_flat]
return nll_loss.mean()
# 数值示例演示
logits_demo = torch.tensor([1.0, 2.0, 3.0, 0.5, -1.0]) # 5个token的logits
target_demo = torch.tensor(2) # "苹果"的索引
probs = torch.softmax(logits_demo, dim=-1)
loss = -torch.log(probs[target_demo])
print("Token 概率分布:")
for i, (name, prob) in enumerate(zip(["香蕉", "橙子", "苹果", "葡萄", "西瓜"], probs)):
marker = " <- 目标" if i == target_demo else ""
print(f" {name}: logit={logits_demo[i]:.1f}, prob={prob:.3f}, -logP={-torch.log(prob):.2f}{marker}")
print(f"\n该位置损失: {loss:.4f}")
# 验证与 PyTorch 一致
logits = torch.randn(2, 10, 50000)
targets = torch.randint(0, 50000, (2, 10))
my_loss = cross_entropy_loss(logits, targets)
torch_loss = nn.functional.cross_entropy(logits.view(-1, 50000), targets.view(-1))
assert torch.allclose(my_loss, torch_loss, atol=1e-4)
print(f"\nMLE Loss 验证通过: {my_loss:.4f}")
面试高频: 为什么语言模型用交叉熵损失?与 MLE 的关系?
交叉熵损失等价于 MLE 在分类问题中的具体实现。最小化交叉熵 = 最大化训练数据的似然。对于下一个 token 预测,交叉熵直接衡量模型输出的概率分布与真实分布的差异,梯度更新让正确 token 的概率增大。
交叉熵
交叉熵衡量两个概率分布的差异。在分类任务中,一个是模型的预测分布 ,另一个是真实标签的 one-hot 分布 。
对于 one-hot 标签(只有正确类别为 1,其余为 0),公式简化为 ,这就是 MLE 中的负对数似然。
数值示例:三分类问题,模型输出 vs 真实标签:
| 类别 | 模型预测 | 真实标签 | 贡献 |
|---|---|---|---|
| “猫” | 0.7 | 0 | 0 |
| ”狗” | 0.2 | 0 | 0 |
| ”鸟” | 0.1 | 1 | |
| 交叉熵 | 2.30 |
模型给正确类别 “鸟” 的概率只有 0.1,损失高达 2.30。如果模型更自信(),损失降至 0.10。
def cross_entropy_manual(pred_probs, true_labels):
"""
手动实现交叉熵
pred_probs: (batch, num_classes) 模型输出的概率
true_labels: (batch,) 真实类别索引
"""
batch_size = pred_probs.shape[0]
# 取正确类别的概率
correct_probs = pred_probs[torch.arange(batch_size), true_labels]
# 负对数似然
loss = -torch.log(correct_probs + 1e-10)
return loss.mean()
# 数值示例验证
probs_demo = torch.tensor([0.7, 0.2, 0.1]) # 模型预测
true_label = torch.tensor(2) # 真实类别是"鸟"
print("交叉熵计算过程:")
for i, name in enumerate(["猫", "狗", "鸟"]):
is_correct = "✓" if i == true_label else " "
contrib = 0.0 if i != true_label else -torch.log(probs_demo[i])
print(f" {is_correct} {name}: P={probs_demo[i]:.1f}, 贡献={contrib:.4f}")
ce_loss = -torch.log(probs_demo[true_label])
print(f"\n总交叉熵损失: {ce_loss:.4f}")
# 对比:模型更自信的情况
confident_probs = torch.tensor([0.05, 0.05, 0.9])
confident_loss = -torch.log(confident_probs[true_label])
print(f"如果 P(鸟)=0.9,损失降至: {confident_loss:.4f}")
# 验证与 PyTorch 一致
pred = torch.tensor([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
labels = torch.tensor([2, 1])
my_ce = cross_entropy_manual(pred, labels)
torch_ce = nn.functional.cross_entropy(torch.log(pred), labels)
print(f"\n交叉熵验证通过: {my_ce:.4f}")
面试高频: 交叉熵与 KL 散度的关系?
交叉熵 ,其中 是真实分布的熵(训练时固定)。因此最小化交叉熵等价于最小化 ,即让模型分布逼近真实分布。
反向传播与梯度下降
反向传播是神经网络训练的核心算法,基于链式法则高效计算梯度。
class SimpleNN(nn.Module):
"""简单两层网络,演示反向传播"""
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
self.z1 = self.fc1(x) # 第一层线性
self.a1 = self.relu(self.z1) # ReLU激活
self.z2 = self.fc2(self.a1) # 第二层线性
return self.z2
def backward(self, x, y, lr=0.01):
"""手动反向传播"""
batch_size = x.shape[0]
# 前向
output = self.forward(x)
loss = nn.MSELoss()(output, y)
# 反向:输出层梯度
dz2 = (output - y) / batch_size # dL/dz2
dw2 = torch.matmul(self.a1.t(), dz2) # dL/dW2 = a1^T @ dz2
db2 = dz2.sum(dim=0) # dL/db2
# 反向:隐藏层梯度
da1 = torch.matmul(dz2, self.fc2.weight) # dL/da1 = dz2 @ W2^T
dz1 = da1 * (self.z1 > 0).float() # dL/dz1 = da1 * ReLU'(z1)
dw1 = torch.matmul(x.t(), dz1) # dL/dW1 = x^T @ dz1
db1 = dz1.sum(dim=0) # dL/db1
# 参数更新(SGD)
with torch.no_grad():
self.fc1.weight -= lr * dw1.t()
self.fc1.bias -= lr * db1
self.fc2.weight -= lr * dw2.t()
self.fc2.bias -= lr * db2
return loss.item()
# 对比手动和自动求导
model = SimpleNN(10, 20, 1)
x = torch.randn(32, 10)
y = torch.randn(32, 1)
# PyTorch 自动求导
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
output = model(x)
loss1 = nn.MSELoss()(output, y)
optimizer.zero_grad()
loss1.backward()
# 保存梯度
grad_auto = model.fc1.weight.grad.clone()
# 重新初始化,手动求导
model = SimpleNN(10, 20, 1)
loss2 = model.backward(x, y, lr=0.01)
print(f"自动求导完成")
print(f"手动求导 Loss: {loss2:.6f}")
面试高频: 链式法则的矩阵形式如何理解?梯度消失/爆炸的本质?
反向传播时,梯度通过雅可比矩阵连乘传递。。梯度消失源于激活函数导数 < 1(如 Sigmoid 最大 0.25),多层连乘后指数级衰减。梯度爆炸则来自权重矩阵特征值 > 1,可用梯度裁剪、残差连接、LayerNorm 缓解。
复杂度速查表
| 算子 | 时间复杂度 | 空间复杂度 | 面试频率 |
|---|---|---|---|
| Sigmoid/GELU/SiLU | O(n) | O(1) | ⭐⭐ |
| LayerNorm/RMSNorm | O(batch × seq × hidden) | O(hidden) | ⭐⭐⭐⭐ |
| Softmax | O(n) | O(n) | ⭐⭐⭐ |
| Self-Attention | O(batch × n² × d) | O(n²) | ⭐⭐⭐⭐⭐ |
| Multi-Head Attention | O(batch × h × n² × d_h) | O(h × n²) | ⭐⭐⭐⭐⭐ |
| KV Cache | O(1) 每步 | O(batch × h × max_len × d_h) | ⭐⭐⭐⭐⭐ |
| RoPE | O(n × d) | O(max_len × d/2) | ⭐⭐⭐⭐ |
| SwiGLU | O(batch × seq × d × hidden) | O(hidden) | ⭐⭐⭐ |
| KL Divergence | O(vocab) | O(vocab) | ⭐⭐⭐⭐ |
说明: n 为序列长度,d 为模型维度,h 为 head 数,d_h = d/h。
掌握这些算子的实现细节,面试中遇到”手撕 Attention”或”讲清楚 LayerNorm”这类问题就能从容应对。代码建议收藏,需要时直接查阅。