多头注意力 (Multi-Head Attention):让模型从不同维度观察语言
📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 前置基础:Self-Attention 自注意力计算 · 位置编码 (Positional Encoding)
1. 为什么单头注意力不够用?
我们先快速回顾一下单头自注意力的核心逻辑:对输入序列的每个词,生成唯一的查询、键、值向量,然后一次全局比较,得到这个词的最终加权表示。
这种“一站式查询”看似高效,但现实语言太复杂了——单个词和其他词的联系,往往不是单一维度的。
1.1 举个生活化的反例
比如拿句子 「程序员小王熬夜写了一篇技术博客」 来说:
- 从语法结构看,「小王」是「写」的主语,「博客」是「写」的宾语,「熬夜」是「写」的状语;
- 从语义角色看,「小王」和「程序员」是身份关联,「技术」和「博客」是主题修饰;
- 从隐含逻辑看,「熬夜」大概率暗示这篇博客“赶稿但可能干货满满”。
如果只给单头注意力一组查询、键、值,它很可能顾此失彼:要么只抓得住最明显的主谓宾,要么分散权重抓一堆无关细节,没法精准覆盖所有有用的多元关系。
2. 多头注意力的核心思路:分而治之
Transformer 团队提出的解法很巧妙:把单头的“万能专家”拆成H个“专精专家”并行协作。
2.1 基本流程拆解
- 头部分割:把输入的嵌入维度(记为
d_model)平均拆成H份,每份叫head_dim(也就是d_model = num_heads × head_dim); - 独立投影:每个头有自己专属的查询、键、值投影矩阵,并行生成自己的子查询、子键、子值;
- 分头计算:每个头用自己的子向量组,独立做一次自注意力计算,得到自己的子输出;
- 拼接融合:把所有头的子输出按顺序拼回
d_model维度; - 线性整合:用一个整合矩阵对拼接后的向量做线性变换,得到最终的多头注意力输出。
2.2 用“专家团队”的类比再讲一遍
我们可以把多头注意力看成一个NLP 语义分析小组:
- 头1是「语法分析师」:只关注词的主谓宾、定语、状语这些结构关系;
- 头2是「实体识别员」:重点抓人名、地名、物品名之间的身份/主题关联;
- 头3是「情感/逻辑挖掘师」:专门看有没有“熬夜赶稿”“开心分享”这种隐含的信息;
- ……(可以根据需求设置更多不同专精的头)
每个专家单独处理完自己的任务后,组长(也就是最后的线性整合矩阵)把大家的分析报告拼在一起,梳理成一份完整、全面的报告,就是最终的语义表示了。
3. PyTorch 内置实现快速上手
PyTorch 的 nn.MultiheadAttention 已经帮我们封装好了所有核心逻辑,不需要自己手写投影、分割、拼接这些步骤,直接调用即可。
下面是一个包含多头注意力的 Transformer 编码器单层实现,可以直接复制测试:
4. 多头注意力的参数:没有增加总数量!
很多人会担心“拆成H个头会不会参数量爆炸”?其实完全不会——多头注意力的总参数量和单头注意力是一样的!
我们用 NLP 入门经典模型 BERT-base 的配置来验证一下:
- BERT-base 核心参数:
d_model=768,num_heads=12,head_dim=768/12=64
4.1 单头注意力的参数量
单头需要4个投影矩阵:
- 子查询、子键、子值其实还是用完整的
d_model→d_model(因为没有分割) - 子输出整合矩阵也是
d_model→d_model - 总参数量 = 3×(768×768) + 768×768 = 4×768×768 ≈ 2.36M
4.2 多头注意力的参数量
拆成12个头后:
- 每个头的投影矩阵是
d_model→head_dim(每个头只负责一小部分) - 12个头总共用3×12×(768×64) = 3×768×(12×64) = 3×768×768
- 最后的整合矩阵还是
d_model→d_model= 768×768 - 总参数量 = 同样是4×768×768 ≈ 2.36M!
结论很明确:多头注意力只是改变了计算的“组织结构”,把参数拆到不同的专精头里,但总规模完全没变——属于“花一样的钱,买更全面的服务”。
5. 经典模型的多头配置参考
实际开发中,我们不需要自己瞎凑配置——遵循主流大模型的经验值通常效果最好。
下表整理了几个入门和进阶常用模型的多头相关核心配置:
5.1 配置经验小总结
- 多头数和嵌入维度的关系:通常遵循
head_dim = 64(这是Transformer原论文里的经验最优值),所以num_heads = d_model / 64。比如d_model=768→12头,d_model=1024→16头,d_model=4096→32头,正好和上面的经典配置一致; - 前馈中间维度:通常是
d_model的4倍左右(BERT-base是3072=4×768,BERT-large是4096=4×1024,Llama 2 7B有特殊设计,但也在3-4倍附近); - 丢弃率:微调阶段通常用0.1左右,预训练大模型后期可能会降到0。
6. 快速小结
我们用3句话总结一下多头注意力的核心:
- 解决的问题:单头注意力无法同时捕捉语言的多元关系;
- 核心思路:把嵌入维度平均拆成H份,每个头用专属参数学习不同的语义模式,最后拼接整合;
- 核心优势:花和单头一样的参数量,得到更全面、更精准的语义表示。
💡 实际开发小技巧: 如果你的显存有限,可以适当减小嵌入维度
d_model,同时保持head_dim=64——这样可以同时降低参数量和计算量,但尽量不要随便改head_dim,因为64是经过大量验证的经验最优值。
🔗 扩展学习资源
- Attention is All You Need 原论文(Transformer的开山之作,建议精读前半部分原理)
- Hugging Face Transformers 库文档(里面有现成的预训练多头注意力模型,直接调用微调即可)
- Transformer 原作者的讲解视频(英文但有字幕,讲得非常通俗易懂)

