多头注意力 (Multi-Head Attention):让模型从不同维度观察语言
📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 前置基础:Self-Attention 自注意力计算 · 位置编码 (Positional Encoding)
1. 为什么单头注意力不够用?
我们先快速回顾一下单头自注意力的核心逻辑:对输入序列的每个词,生成唯一的查询、键、值向量,然后一次全局比较,得到这个词的最终加权表示。
这种“一站式查询”看似高效,但现实语言太复杂了——单个词和其他词的联系,往往不是单一维度的。
1.1 举个生活化的反例
比如拿句子 「程序员小王熬夜写了一篇技术博客」 来说:
- 从语法结构看,「小王」是「写」的主语,「博客」是「写」的宾语,「熬夜」是「写」的状语;
- 从语义角色看,「小王」和「程序员」是身份关联,「技术」和「博客」是主题修饰;
- 从隐含逻辑看,「熬夜」大概率暗示这篇博客“赶稿但可能干货满满”。
如果只给单头注意力一组查询、键、值,它很可能顾此失彼:要么只抓得住最明显的主谓宾,要么分散权重抓一堆无关细节,没法精准覆盖所有有用的多元关系。
2. 多头注意力的核心思路:分而治之
Transformer 团队提出的解法很巧妙:把单头的“万能专家”拆成H个“专精专家”并行协作。
2.1 基本流程拆解
- 头部分割:把输入的嵌入维度(记作
d_model)平均拆成 H 份,每份叫head_dim(也就是说d_model = num_heads × head_dim); - 独立投影:每个头有自己专属的查询、键、值投影矩阵,并行生成自己的子查询、子键、子值;
- 分头计算:每个头用自己的子向量组,独立做一次自注意力计算,得到自己的子输出;
- 拼接融合:把所有头的子输出按顺序拼回
d_model维度; - 线性整合:用一个整合矩阵对拼接后的向量做线性变换,得到最终的多头注意力输出。
2.2 用“专家团队”的类比再讲一遍
我们可以把多头注意力看成一个NLP 语义分析小组:
- 头1是「语法分析师」:只关注词的主谓宾、定语、状语这些结构关系;
- 头2是「实体识别员」:重点抓人名、地名、物品名之间的身份/主题关联;
- 头3是「情感/逻辑挖掘师」:专门看有没有“熬夜赶稿”“开心分享”这种隐含的信息;
- ……(可以根据需求设置更多不同专精的头)
每个专家单独处理完自己的任务后,组长(也就是最后的线性整合矩阵)把大家的分析报告拼在一起,梳理成一份完整、全面的报告,就是最终的语义表示了。
3. PyTorch 内置实现快速上手
PyTorch 的 nn.MultiheadAttention 已经帮我们封装好了所有核心逻辑,不需要自己手写投影、分割、拼接这些步骤,直接调用即可。
下面是一个包含多头注意力的 Transformer 编码器单层实现,可以直接复制测试:
4. 多头注意力的参数:没有增加总数量!
很多人会担心“拆成 H 个头会不会参数量爆炸”?其实完全不会——多头注意力的总参数量和单头注意力是一样的!
我们用 NLP 入门经典模型 BERT-base 的配置来验证一下:
- BERT-base 核心参数:
d_model=768,num_heads=12,那么head_dim = 768/12 = 64
4.1 单头注意力的参数量
单头需要 4 个投影矩阵:
- 查询、键、值的投影矩阵都是
d_model × d_model(因为没有分割,整个维度一起做变换) - 最后的整合矩阵也是
d_model × d_model - 总参数量 = 3 × (768 × 768) + 768 × 768 = 4 × 768 × 768 ≈ 2.36M
4.2 多头注意力的参数量
拆成 12 个头后:
- 每个头的查询、键、值投影矩阵变成了
d_model × head_dim(只负责一部分维度) - 12 个头总共用 3 × 12 × (768 × 64) = 3 × 768 × (12 × 64) = 3 × 768 × 768
- 最后的整合矩阵还是
d_model × d_model= 768 × 768 - 总参数量 = 同样是 4 × 768 × 768 ≈ 2.36M!
结论很明确:多头注意力只是改变了计算的“组织结构”,把参数拆到不同的专精头里,但总规模完全没变——属于“花一样的钱,买更全面的服务”。
5. 经典模型的多头配置参考
实际开发中,我们不需要自己瞎凑配置——遵循主流大模型的经验值通常效果最好。
下表整理了几个入门和进阶常用模型的多头相关核心配置:
5.1 配置经验小总结
- 多头数和嵌入维度的关系:通常遵循
head_dim = 64(这是 Transformer 原论文里的经验最优值),所以num_heads = d_model / 64。比如d_model=768→12 头,d_model=1024→16 头,d_model=4096→32 头,正好和上面的经典配置一致; - 前馈中间维度:通常是
d_model的 4 倍左右(BERT-base 是 3072=4×768,BERT-large 是 4096=4×1024,Llama 2 7B 有特殊设计,但也在 3-4 倍附近); - 丢弃率:微调阶段通常用 0.1 左右,预训练大模型后期可能会降到 0。
6. 快速小结
我们用 3 句话总结一下多头注意力的核心:
- 解决的问题:单头注意力无法同时捕捉语言的多元关系;
- 核心思路:把嵌入维度平均拆成 H 份,每个头用专属参数学习不同的语义模式,最后拼接整合;
- 核心优势:花和单头一样的参数量,得到更全面、更精准的语义表示。
💡 实际开发小技巧: 如果你的显存有限,可以适当减小嵌入维度
d_model,同时保持head_dim=64——这样可以同时降低参数量和计算量,但尽量不要随便改head_dim,因为 64 是经过大量验证的经验最优值。
🔗 扩展学习资源
- Attention is All You Need 原论文(Transformer 的开山之作,建议精读前半部分原理)
- Hugging Face Transformers 库文档(里面有现成的预训练多头注意力模型,直接调用微调即可)
- Transformer 原作者的讲解视频(英文但有字幕,讲得非常通俗易懂)

