多头注意力 (Multi-Head Attention):让模型从不同维度观察语言

📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 相关章节:Self-Attention 自注意力计算 · 位置编码 (Positional Encoding)


1. 为什么需要多头?

1.1 单头的问题

单头 Attention = 一个查询向量,一次关注

问题:现实语言中,词与词之间的关系是多元的

"猫 追 老鼠"

单头可能只学到一种关系,无法同时捕捉:
  → 猫是主语,老鼠是宾语(语法关系)
  → 猫和老鼠是动物(语义类别)
  → 猫先于老鼠(时序关系)

1.2 多头的解决方案

Multi-Head = H 个独立的 Self-Attention 并行运行

每个头有自己的 W_q, W_k, W_v 参数
→ 每个头学习不同的关注模式

类比:
  头1:语法专家(捕捉主谓宾关系)
  头2:语义专家(捕捉实体关系)
  头3:位置专家(捕捉上下文顺序)
  头4:……

最终拼接多个头的输出 → 更丰富的语义表示

2. PyTorch 内置实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,  # PyTorch 新版支持!
        )
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-Attention + 残差连接 + LayerNorm
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed Forward + 残差连接 + LayerNorm
        ff_output = self.linear2(F.gelu(self.linear1(x)))
        x = self.norm2(x + self.dropout(ff_output))
        return x


# 使用
encoder_layer = TransformerEncoderLayer(d_model=512, num_heads=8)
x = torch.randn(32, 100, 512)  # (batch, seq_len, d_model)
output = encoder_layer(x)
print(output.shape)  # torch.Size([32, 100, 512])

3. 多头注意力的参数分析

"""
BERT-base 配置:
  d_model = 768(嵌入维度)
  num_heads = 12(12 个头)
  head_dim = 768 / 12 = 64(每个头 64 维)

  单个头的参数量:
  W_q: 768 × 64 = 49,152
  W_k: 768 × 64 = 49,152
  W_v: 768 × 64 = 49,152
  W_o: 768 × 768 = 589,824
  总计 ≈ 737K 参数(与单头相同!)

结论:多头不增加总参数量,只改变计算结构
"""

4. 经典配置

模型d_modelnum_headsd_ffdropout
BERT-base7681230720.1
BERT-large10241640960.1
GPT-2 small7681230720.1
GPT-2 medium10241640960.1

5. 小结

Multi-Head Attention:

d_model = num_heads × head_dim

每个头独立计算 Attention → 拼接 → 线性变换

优势:
- 不增加参数量
- 每个头学习不同的语义模式
- 最终表示更丰富

💡 经验之谈:多头数量通常是 d_model/64(如 d_model=768 → 12头;d_model=1024 → 16头),这是一个经验上效果较好的配置。


🔗 扩展阅读