MLA多头潜注意力:DeepSeek-V3的核心创新

从GQA到MLA:还能再省吗?

上一篇我们讲了GQA,通过让多个Q head共享KV来减少KV Cache。BuddyGPT 0.3B用16Q/8KV的配置,比MHA节省了50%的KV Cache。

但DeepSeek团队问了一个更根本的问题:KV Cache存的一定要是完整的Key和Value吗?能不能存一个更小的"压缩版"?

这就是MLA(Multi-head Latent Attention)的出发点。

MLA核心思想:低秩压缩

MLA的核心思路用三句话概括:

  1. 训练时:先把hidden state压缩到一个低维的latent向量,再从latent展开为完整的Q/K/V
  2. 推理时:KV Cache只需要存低维的latent,而不是完整的K和V
  3. 效果:KV Cache可以压缩到原来的1/6甚至更小,而质量几乎不损失

用图来理解:


传统GQA的KV Cache:
┌───────────────────────────────────────┐
│  缓存完整的K和V                        │
│  K: [n_kv_heads × head_dim] = 8×64    │ = 512维
│  V: [n_kv_heads × head_dim] = 8×64    │ = 512维
│  总计: 1024维/token                    │
└───────────────────────────────────────┘

MLA的KV Cache:
┌───────────────────────────────────────┐
│  只缓存低维latent                      │
│  kv_latent: [kv_lora_rank] = 16维      │
│  k_rope:    [qk_rope_head_dim] = 24维  │ (位置编码部分需要单独存)
│  总计: 40维/token                      │ ← 比GQA小25倍!
└───────────────────────────────────────┘

MLA的数据流

MLA的完整数据流比GQA复杂,但逻辑很清晰。我们分Q和KV两条路径来看:


                         hidden_states [B, L, hidden_size=1024]
                                    │
                    ┌───────────────┴───────────────┐
                    │                               │
              ┌─────v─────┐                   ┌─────v─────┐
              │ Q 路径     │                   │ KV 路径    │
              └─────┬─────┘                   └─────┬─────┘
                    │                               │
              q_down_proj                     kv_down_proj
              1024 → 16                       1024 → 16
              (低秩压缩)                       (低秩压缩)
                    │                               │
              RMSNorm                          RMSNorm
                    │                               │
              q_up_proj                        kv_up_proj
              16 → n_head*(nope+rope)          16 → n_head*(nope+rope+v)
                    │                               │
              ┌─────┴─────┐                   ┌─────┴──────────┐
              │           │                   │        │       │
           q_nope      q_rope              k_nope   k_rope    V
           (72维)      (24维)              (72维)   (24维)   (96维)
              │           │                   │        │       │
              │       apply_rope              │    apply_rope  │
              │           │                   │        │       │
              └─────┬─────┘                   └────┬───┘       │
                    │                              │           │
                    Q (96维)                       K (96维)    V (96维)
                    │                              │           │
                    └──────────── Attention ────────┘───────────┘

关键点:

  • Q和KV都先经过"下投影"压缩到低维(rank=16),再"上投影"展开
  • Q和K都分成两部分:nope(不旋转)和 rope(旋转)
  • 只有rope部分应用RoPE位置编码
  • 推理时,KV Cache只需存kv_latent(16维)+ k_rope(24维)= 40维

为什么要分nope和rope?

这是MLA设计中一个精妙的细节。回忆一下,RoPE是通过旋转来编码位置的。如果对整个K都应用旋转,那么从latent解压出来的K就包含了位置信息,KV Cache中的latent就不是位置无关的了——每个位置的latent不同,就失去了压缩的意义。

MLA的解决方案:把K分成两部分,一部分不旋转(从latent解压),一部分旋转(单独存储)。 这样latent保持位置无关,可以最大程度压缩;位置信息只通过独立的k_rope通道传递。

BuddyGPT的MLA参数


# BuddyGPT MLA 配置
q_lora_rank = 16        # Q的低秩压缩维度
kv_lora_rank = 16       # KV的低秩压缩维度
qk_nope_head_dim = 72   # Q/K中不旋转的维度
qk_rope_head_dim = 24   # Q/K中旋转的维度(应用RoPE)
v_head_dim = 96          # V的维度
n_heads = 16             # 注意力头数

每个head的Q和K有 72 + 24 = 96 维,其中72维通过latent压缩/解压(不受位置影响),24维直接用于RoPE位置编码。

MLA代码实现


class MLAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.q_lora_rank = config.q_lora_rank          # 16
        self.kv_lora_rank = config.kv_lora_rank        # 16
        self.qk_nope_head_dim = config.qk_nope_head_dim  # 72
        self.qk_rope_head_dim = config.qk_rope_head_dim  # 24
        self.v_head_dim = config.v_head_dim            # 96
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim  # 96

        # ===== Q 路径: hidden → latent → 多头Q =====
        # 下投影:压缩到低秩空间
        self.q_down_proj = nn.Linear(
            config.hidden_size, self.q_lora_rank, bias=False
        )  # 1024 → 16
        self.q_down_layernorm = RMSNorm(self.q_lora_rank)
        # 上投影:从低秩展开为所有head的Q
        self.q_up_proj = nn.Linear(
            self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False
        )  # 16 → 16 * 96 = 1536

        # ===== KV 路径: hidden → latent → 多头K和V =====
        self.kv_down_proj = nn.Linear(
            config.hidden_size, self.kv_lora_rank, bias=False
        )  # 1024 → 16
        self.kv_down_layernorm = RMSNorm(self.kv_lora_rank)
        # 上投影:展开为K_nope + K_rope + V
        self.kv_up_proj = nn.Linear(
            self.kv_lora_rank,
            self.n_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False
        )  # 16 → 16 * (72 + 96) = 2688
        # 注意:k_rope不从latent来,而是单独投影
        self.k_rope_proj = nn.Linear(
            config.hidden_size, self.n_heads * self.qk_rope_head_dim, bias=False
        )  # 1024 → 16 * 24 = 384

        # 输出投影
        self.o_proj = nn.Linear(
            self.n_heads * self.v_head_dim, config.hidden_size, bias=False
        )  # 16 * 96 = 1536 → 1024

        # RoPE
        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)

    def forward(self, hidden_states, position_ids=None):
        B, L, _ = hidden_states.shape

        # ===== Q: 压缩 → 归一化 → 展开 → 拆分nope和rope =====
        q_latent = self.q_down_proj(hidden_states)               # [B,L,16]
        q_latent = self.q_down_layernorm(q_latent)               # [B,L,16]
        q = self.q_up_proj(q_latent)                             # [B,L,1536]
        q = q.view(B, L, self.n_heads, self.qk_head_dim)        # [B,L,16,96]
        q = q.transpose(1, 2)                                    # [B,16,L,96]
        # 拆分:前72维不旋转,后24维旋转
        q_nope, q_rope = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )  # [B,16,L,72] 和 [B,16,L,24]

        # ===== KV: 压缩 → 归一化 → 展开 → 拆分 =====
        kv_latent = self.kv_down_proj(hidden_states)             # [B,L,16]
        kv_latent = self.kv_down_layernorm(kv_latent)            # [B,L,16]
        kv = self.kv_up_proj(kv_latent)                          # [B,L,2688]
        kv = kv.view(B, L, self.n_heads, -1).transpose(1, 2)    # [B,16,L,168]
        k_nope, v = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )  # [B,16,L,72] 和 [B,16,L,96]

        # K的rope部分:单独从hidden投影(不经过latent)
        k_rope = self.k_rope_proj(hidden_states)                 # [B,L,384]
        k_rope = k_rope.view(B, L, self.n_heads, -1).transpose(1, 2)  # [B,16,L,24]

        # ===== 应用RoPE:只对rope部分旋转 =====
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin)

        # ===== 拼接nope和rope部分,组成完整的Q和K =====
        q = torch.cat([q_nope, q_rope], dim=-1)  # [B,16,L,96]
        k = torch.cat([k_nope, k_rope], dim=-1)  # [B,16,L,96]

        # ===== 标准Attention计算 =====
        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_output = attn_output.transpose(1, 2).reshape(B, L, -1)
        return self.o_proj(attn_output)

MLA vs GQA:KV Cache对比

让我们用具体数字算一下。以BuddyGPT 0.3B为例,序列长度2048,FP16精度:


GQA (8 KV heads, head_dim=64):
  每token: 2(K+V) × 8 heads × 64 dim × 2 bytes = 2,048 bytes
  每层:    2,048 × 2,048 tokens = 4,194,304 bytes ≈ 4 MB
  24层:    4 MB × 24 = 96 MB

MLA (kv_lora_rank=16, rope_dim=24):
  每token: (16 latent + 24 rope) × 2 bytes = 80 bytes
  每层:    80 × 2,048 tokens = 163,840 bytes ≈ 0.16 MB
  24层:    0.16 MB × 24 = 3.84 MB

压缩比: 96 MB / 3.84 MB ≈ 25x !!!

MLA在KV Cache上实现了约25倍的压缩!这意味着同样的显存可以支持25倍的batch size或25倍的序列长度。对于大规模部署来说,这是巨大的成本节省。

当然,MLA的计算量并没有减少——训练时还是要做完整的Q/K/V计算。它优化的是推理时的内存瓶颈。

个人体会

实现MLA之前,我觉得它一定很复杂。但真正写完代码后发现:MLA的实现并不难,关键是理解"压缩-解压缩"的思想。

它的本质就是一个autoencoder的思路:

  • encoder(down_proj):把高维信息压缩到低维latent
  • decoder(up_proj):从低维latent恢复高维信息
  • 关键洞察:推理时只需要缓存压缩后的latent

这和图像压缩(JPEG)、视频编码(H.264)的思路如出一辙——信息是有冗余的,低秩近似可以用很少的参数捕获主要信息。

另一个体会是:MLA中nope/rope的分离设计非常精巧。它完美解决了"低秩压缩"和"位置编码"之间的矛盾——latent保持位置无关以最大化压缩,位置信息通过独立通道传递。这种"正交分解"的设计思路值得学习。

在BuddyGPT 0.7B中,我们正在实验MLA+MOE的组合。目前的初步结果是:MLA确实能在不损失质量的前提下大幅减少KV Cache,而MOE则能在不增加推理计算量的前提下扩大模型容量。两者的结合,就是DeepSeek-V3如此强大的秘密之一。

参考资料