多头注意力 (Multi-Head Attention):让模型从不同维度观察语言

📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 前置基础:Self-Attention 自注意力计算 · 位置编码 (Positional Encoding)


1. 为什么单头注意力不够用?

我们先快速回顾一下单头自注意力的核心逻辑:对输入序列的每个词,生成唯一的查询、键、值向量,然后一次全局比较,得到这个词的最终加权表示。

这种“一站式查询”看似高效,但现实语言太复杂了——单个词和其他词的联系,往往不是单一维度的。

1.1 举个生活化的反例

比如拿句子 「程序员小王熬夜写了一篇技术博客」 来说:

  • 语法结构看,「小王」是「写」的主语,「博客」是「写」的宾语,「熬夜」是「写」的状语;
  • 语义角色看,「小王」和「程序员」是身份关联,「技术」和「博客」是主题修饰;
  • 隐含逻辑看,「熬夜」大概率暗示这篇博客“赶稿但可能干货满满”。

如果只给单头注意力一组查询、键、值,它很可能顾此失彼:要么只抓得住最明显的主谓宾,要么分散权重抓一堆无关细节,没法精准覆盖所有有用的多元关系。


2. 多头注意力的核心思路:分而治之

Transformer 团队提出的解法很巧妙:把单头的“万能专家”拆成H个“专精专家”并行协作

2.1 基本流程拆解

  1. 头部分割:把输入的嵌入维度(记作 d_model平均拆成 H 份,每份叫 head_dim(也就是说 d_model = num_heads × head_dim);
  2. 独立投影:每个头有自己专属的查询、键、值投影矩阵,并行生成自己的子查询、子键、子值;
  3. 分头计算:每个头用自己的子向量组,独立做一次自注意力计算,得到自己的子输出;
  4. 拼接融合:把所有头的子输出按顺序拼回 d_model 维度;
  5. 线性整合:用一个整合矩阵对拼接后的向量做线性变换,得到最终的多头注意力输出。

2.2 用“专家团队”的类比再讲一遍

我们可以把多头注意力看成一个NLP 语义分析小组

  • 头1是「语法分析师」:只关注词的主谓宾、定语、状语这些结构关系;
  • 头2是「实体识别员」:重点抓人名、地名、物品名之间的身份/主题关联;
  • 头3是「情感/逻辑挖掘师」:专门看有没有“熬夜赶稿”“开心分享”这种隐含的信息;
  • ……(可以根据需求设置更多不同专精的头)

每个专家单独处理完自己的任务后,组长(也就是最后的线性整合矩阵)把大家的分析报告拼在一起,梳理成一份完整、全面的报告,就是最终的语义表示了。


3. PyTorch 内置实现快速上手

PyTorch 的 nn.MultiheadAttention 已经帮我们封装好了所有核心逻辑,不需要自己手写投影、分割、拼接这些步骤,直接调用即可。

下面是一个包含多头注意力的 Transformer 编码器单层实现,可以直接复制测试:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 带多头自注意力的 Transformer 编码器单层
class TransformerEncoderLayerDemo(nn.Module):
    def __init__(
        self,
        d_model: int = 512,      # 输入/输出的嵌入维度
        num_heads: int = 8,       # 多头数量
        d_ff: int = 2048,         # 前馈网络的中间维度
        dropout: float = 0.1,      # 防止过拟合的丢弃率
    ):
        super().__init__()
        
        # 核心:多头自注意力模块
        # 注意:PyTorch 2.x+ 推荐用 batch_first=True,维度顺序更直观
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        
        # 前馈网络(Feed Forward Network)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        
        # 残差连接 + LayerNorm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        # 第一步:多头自注意力 + 残差 + LayerNorm
        # self_attn 的输入:query, key, value,这里用自注意力,三个都是x
        # 输出:(多头注意力结果, 注意力权重矩阵,可选是否返回)
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_output))  # 残差连接:输入 + 处理后的输出

        # 第二步:前馈网络 + 残差 + LayerNorm
        ff_output = self.linear2(F.gelu(self.linear1(x)))
        x = self.norm2(x + self.dropout(ff_output))
        
        return x


# ---------------- 测试代码 ----------------
if __name__ == "__main__":
    # 初始化编码器单层
    encoder_layer = TransformerEncoderLayerDemo(d_model=512, num_heads=8)
    
    # 构造测试输入:(batch_size, seq_len, d_model)
    # batch_size=32:一次处理32个句子
    # seq_len=100:每个句子有100个词(不足补全,超过截断)
    # d_model=512:每个词的嵌入维度是512
    test_input = torch.randn(32, 100, 512)
    
    # 前向传播
    test_output = encoder_layer(test_input)
    
    # 输出维度应该和输入一致!
    print(f"输入维度: {test_input.shape}")    # torch.Size([32, 100, 512])
    print(f"输出维度: {test_output.shape}")  # torch.Size([32, 100, 512])

4. 多头注意力的参数:没有增加总数量!

很多人会担心“拆成 H 个头会不会参数量爆炸”?其实完全不会——多头注意力的总参数量和单头注意力是一样的

我们用 NLP 入门经典模型 BERT-base 的配置来验证一下:

  • BERT-base 核心参数:d_model=768num_heads=12,那么 head_dim = 768/12 = 64

4.1 单头注意力的参数量

单头需要 4 个投影矩阵:

  • 查询、键、值的投影矩阵都是 d_model × d_model(因为没有分割,整个维度一起做变换)
  • 最后的整合矩阵也是 d_model × d_model
  • 总参数量 = 3 × (768 × 768) + 768 × 768 = 4 × 768 × 768 ≈ 2.36M

4.2 多头注意力的参数量

拆成 12 个头后:

  • 每个头的查询、键、值投影矩阵变成了 d_model × head_dim(只负责一部分维度)
  • 12 个头总共用 3 × 12 × (768 × 64) = 3 × 768 × (12 × 64) = 3 × 768 × 768
  • 最后的整合矩阵还是 d_model × d_model = 768 × 768
  • 总参数量 = 同样是 4 × 768 × 768 ≈ 2.36M

结论很明确:多头注意力只是改变了计算的“组织结构”,把参数拆到不同的专精头里,但总规模完全没变——属于“花一样的钱,买更全面的服务”


5. 经典模型的多头配置参考

实际开发中,我们不需要自己瞎凑配置——遵循主流大模型的经验值通常效果最好。

下表整理了几个入门和进阶常用模型的多头相关核心配置:

模型名称d_model(嵌入维度)num_heads(多头数)d_ff(前馈中间维度)dropout(丢弃率)
BERT-base7681230720.1
BERT-large10241640960.1
GPT-2 small7681230720.1
GPT-2 medium10241640960.1
Llama 2 7B409632110080.0

5.1 配置经验小总结

  1. 多头数和嵌入维度的关系:通常遵循 head_dim = 64(这是 Transformer 原论文里的经验最优值),所以 num_heads = d_model / 64。比如 d_model=768→12 头,d_model=1024→16 头,d_model=4096→32 头,正好和上面的经典配置一致;
  2. 前馈中间维度:通常是 d_model 的 4 倍左右(BERT-base 是 3072=4×768,BERT-large 是 4096=4×1024,Llama 2 7B 有特殊设计,但也在 3-4 倍附近);
  3. 丢弃率:微调阶段通常用 0.1 左右,预训练大模型后期可能会降到 0。

6. 快速小结

我们用 3 句话总结一下多头注意力的核心:

  1. 解决的问题:单头注意力无法同时捕捉语言的多元关系;
  2. 核心思路:把嵌入维度平均拆成 H 份,每个头用专属参数学习不同的语义模式,最后拼接整合;
  3. 核心优势:花和单头一样的参数量,得到更全面、更精准的语义表示。

💡 实际开发小技巧: 如果你的显存有限,可以适当减小嵌入维度 d_model,同时保持 head_dim=64——这样可以同时降低参数量和计算量,但尽量不要随便改 head_dim,因为 64 是经过大量验证的经验最优值。


🔗 扩展学习资源