从GQA到MLA:还能再省吗?
上一篇我们讲了GQA,通过让多个Q head共享KV来减少KV Cache。BuddyGPT 0.3B用16Q/8KV的配置,比MHA节省了50%的KV Cache。
但DeepSeek团队问了一个更根本的问题:KV Cache存的一定要是完整的Key和Value吗?能不能存一个更小的"压缩版"?
这就是MLA(Multi-head Latent Attention)的出发点。
MLA核心思想:低秩压缩
MLA的核心思路用三句话概括:
- 训练时:先把hidden state压缩到一个低维的latent向量,再从latent展开为完整的Q/K/V
- 推理时:KV Cache只需要存低维的latent,而不是完整的K和V
- 效果:KV Cache可以压缩到原来的1/6甚至更小,而质量几乎不损失
用图来理解:
传统GQA的KV Cache:
┌───────────────────────────────────────┐
│ 缓存完整的K和V │
│ K: [n_kv_heads × head_dim] = 8×64 │ = 512维
│ V: [n_kv_heads × head_dim] = 8×64 │ = 512维
│ 总计: 1024维/token │
└───────────────────────────────────────┘
MLA的KV Cache:
┌───────────────────────────────────────┐
│ 只缓存低维latent │
│ kv_latent: [kv_lora_rank] = 16维 │
│ k_rope: [qk_rope_head_dim] = 24维 │ (位置编码部分需要单独存)
│ 总计: 40维/token │ ← 比GQA小25倍!
└───────────────────────────────────────┘
MLA的数据流
MLA的完整数据流比GQA复杂,但逻辑很清晰。我们分Q和KV两条路径来看:
hidden_states [B, L, hidden_size=1024]
│
┌───────────────┴───────────────┐
│ │
┌─────v─────┐ ┌─────v─────┐
│ Q 路径 │ │ KV 路径 │
└─────┬─────┘ └─────┬─────┘
│ │
q_down_proj kv_down_proj
1024 → 16 1024 → 16
(低秩压缩) (低秩压缩)
│ │
RMSNorm RMSNorm
│ │
q_up_proj kv_up_proj
16 → n_head*(nope+rope) 16 → n_head*(nope+rope+v)
│ │
┌─────┴─────┐ ┌─────┴──────────┐
│ │ │ │ │
q_nope q_rope k_nope k_rope V
(72维) (24维) (72维) (24维) (96维)
│ │ │ │ │
│ apply_rope │ apply_rope │
│ │ │ │ │
└─────┬─────┘ └────┬───┘ │
│ │ │
Q (96维) K (96维) V (96维)
│ │ │
└──────────── Attention ────────┘───────────┘
关键点:
- Q和KV都先经过"下投影"压缩到低维(rank=16),再"上投影"展开
- Q和K都分成两部分:
nope(不旋转)和rope(旋转) - 只有rope部分应用RoPE位置编码
- 推理时,KV Cache只需存kv_latent(16维)+ k_rope(24维)= 40维
为什么要分nope和rope?
这是MLA设计中一个精妙的细节。回忆一下,RoPE是通过旋转来编码位置的。如果对整个K都应用旋转,那么从latent解压出来的K就包含了位置信息,KV Cache中的latent就不是位置无关的了——每个位置的latent不同,就失去了压缩的意义。
MLA的解决方案:把K分成两部分,一部分不旋转(从latent解压),一部分旋转(单独存储)。 这样latent保持位置无关,可以最大程度压缩;位置信息只通过独立的k_rope通道传递。
BuddyGPT的MLA参数
# BuddyGPT MLA 配置
q_lora_rank = 16 # Q的低秩压缩维度
kv_lora_rank = 16 # KV的低秩压缩维度
qk_nope_head_dim = 72 # Q/K中不旋转的维度
qk_rope_head_dim = 24 # Q/K中旋转的维度(应用RoPE)
v_head_dim = 96 # V的维度
n_heads = 16 # 注意力头数
每个head的Q和K有 72 + 24 = 96 维,其中72维通过latent压缩/解压(不受位置影响),24维直接用于RoPE位置编码。
MLA代码实现
class MLAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.n_heads = config.num_attention_heads
self.q_lora_rank = config.q_lora_rank # 16
self.kv_lora_rank = config.kv_lora_rank # 16
self.qk_nope_head_dim = config.qk_nope_head_dim # 72
self.qk_rope_head_dim = config.qk_rope_head_dim # 24
self.v_head_dim = config.v_head_dim # 96
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim # 96
# ===== Q 路径: hidden → latent → 多头Q =====
# 下投影:压缩到低秩空间
self.q_down_proj = nn.Linear(
config.hidden_size, self.q_lora_rank, bias=False
) # 1024 → 16
self.q_down_layernorm = RMSNorm(self.q_lora_rank)
# 上投影:从低秩展开为所有head的Q
self.q_up_proj = nn.Linear(
self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False
) # 16 → 16 * 96 = 1536
# ===== KV 路径: hidden → latent → 多头K和V =====
self.kv_down_proj = nn.Linear(
config.hidden_size, self.kv_lora_rank, bias=False
) # 1024 → 16
self.kv_down_layernorm = RMSNorm(self.kv_lora_rank)
# 上投影:展开为K_nope + K_rope + V
self.kv_up_proj = nn.Linear(
self.kv_lora_rank,
self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False
) # 16 → 16 * (72 + 96) = 2688
# 注意:k_rope不从latent来,而是单独投影
self.k_rope_proj = nn.Linear(
config.hidden_size, self.n_heads * self.qk_rope_head_dim, bias=False
) # 1024 → 16 * 24 = 384
# 输出投影
self.o_proj = nn.Linear(
self.n_heads * self.v_head_dim, config.hidden_size, bias=False
) # 16 * 96 = 1536 → 1024
# RoPE
self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)
def forward(self, hidden_states, position_ids=None):
B, L, _ = hidden_states.shape
# ===== Q: 压缩 → 归一化 → 展开 → 拆分nope和rope =====
q_latent = self.q_down_proj(hidden_states) # [B,L,16]
q_latent = self.q_down_layernorm(q_latent) # [B,L,16]
q = self.q_up_proj(q_latent) # [B,L,1536]
q = q.view(B, L, self.n_heads, self.qk_head_dim) # [B,L,16,96]
q = q.transpose(1, 2) # [B,16,L,96]
# 拆分:前72维不旋转,后24维旋转
q_nope, q_rope = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
) # [B,16,L,72] 和 [B,16,L,24]
# ===== KV: 压缩 → 归一化 → 展开 → 拆分 =====
kv_latent = self.kv_down_proj(hidden_states) # [B,L,16]
kv_latent = self.kv_down_layernorm(kv_latent) # [B,L,16]
kv = self.kv_up_proj(kv_latent) # [B,L,2688]
kv = kv.view(B, L, self.n_heads, -1).transpose(1, 2) # [B,16,L,168]
k_nope, v = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
) # [B,16,L,72] 和 [B,16,L,96]
# K的rope部分:单独从hidden投影(不经过latent)
k_rope = self.k_rope_proj(hidden_states) # [B,L,384]
k_rope = k_rope.view(B, L, self.n_heads, -1).transpose(1, 2) # [B,16,L,24]
# ===== 应用RoPE:只对rope部分旋转 =====
cos, sin = self.rotary_emb(hidden_states, position_ids)
q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin)
# ===== 拼接nope和rope部分,组成完整的Q和K =====
q = torch.cat([q_nope, q_rope], dim=-1) # [B,16,L,96]
k = torch.cat([k_nope, k_rope], dim=-1) # [B,16,L,96]
# ===== 标准Attention计算 =====
attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
attn_output = attn_output.transpose(1, 2).reshape(B, L, -1)
return self.o_proj(attn_output)
MLA vs GQA:KV Cache对比
让我们用具体数字算一下。以BuddyGPT 0.3B为例,序列长度2048,FP16精度:
GQA (8 KV heads, head_dim=64):
每token: 2(K+V) × 8 heads × 64 dim × 2 bytes = 2,048 bytes
每层: 2,048 × 2,048 tokens = 4,194,304 bytes ≈ 4 MB
24层: 4 MB × 24 = 96 MB
MLA (kv_lora_rank=16, rope_dim=24):
每token: (16 latent + 24 rope) × 2 bytes = 80 bytes
每层: 80 × 2,048 tokens = 163,840 bytes ≈ 0.16 MB
24层: 0.16 MB × 24 = 3.84 MB
压缩比: 96 MB / 3.84 MB ≈ 25x !!!
MLA在KV Cache上实现了约25倍的压缩!这意味着同样的显存可以支持25倍的batch size或25倍的序列长度。对于大规模部署来说,这是巨大的成本节省。
当然,MLA的计算量并没有减少——训练时还是要做完整的Q/K/V计算。它优化的是推理时的内存瓶颈。
个人体会
实现MLA之前,我觉得它一定很复杂。但真正写完代码后发现:MLA的实现并不难,关键是理解"压缩-解压缩"的思想。
它的本质就是一个autoencoder的思路:
- encoder(down_proj):把高维信息压缩到低维latent
- decoder(up_proj):从低维latent恢复高维信息
- 关键洞察:推理时只需要缓存压缩后的latent
这和图像压缩(JPEG)、视频编码(H.264)的思路如出一辙——信息是有冗余的,低秩近似可以用很少的参数捕获主要信息。
另一个体会是:MLA中nope/rope的分离设计非常精巧。它完美解决了"低秩压缩"和"位置编码"之间的矛盾——latent保持位置无关以最大化压缩,位置信息通过独立通道传递。这种"正交分解"的设计思路值得学习。
在BuddyGPT 0.7B中,我们正在实验MLA+MOE的组合。目前的初步结果是:MLA确实能在不损失质量的前提下大幅减少KV Cache,而MOE则能在不增加推理计算量的前提下扩大模型容量。两者的结合,就是DeepSeek-V3如此强大的秘密之一。
参考资料
- BuddyGPT项目地址
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model - MLA首次提出
- DeepSeek-V3 Technical Report - MLA的改进版本
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
- LoRA: Low-Rank Adaptation of Large Language Models - 低秩近似的思想来源
- DeepSeek-V2 详解 - 知乎