位置编码 (Positional Encoding):给没有顺序的矩阵注入"位置感"

📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 相关章节:多头注意力 (Multi-Head Attention) · Transformer 完整架构


1. 为什么需要位置编码?

1.1 Self-Attention 的缺陷

Self-Attention:完全并行、无视位置

输入:"I love NLP" 和 "NLP love I"
→ 变换后的 Attention 结果完全相同!
→ 模型无法区分词语顺序!

"我 打 你" vs "你 打 我"
→ 词序颠倒,意思相反,但 Attention 输出相同!

1.2 解决方案

给每个位置一个唯一的"位置信号" → 加到词嵌入上

词嵌入:捕捉语义信息
位置编码:捕捉位置信息

最终表示 = 词嵌入 + 位置编码

2. 正弦/余弦位置编码

2.1 原始 Transformer 公式

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # 偶数维度用 sin,奇数维度用 cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # (max_len, d_model) → (1, max_len, d_model)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

2.2 为什么用正弦余弦?

正弦余弦位置编码的特性:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

优势:
1. 任意两个位置的差,可以用线性变换表示
   → PE(pos+k) 可以用 PE(pos) 的线性组合表示!
   → 模型可以学习相对位置关系!

2. 可以泛化到训练时未见过的序列长度
3. 计算高效,无需学习参数

3. 可学习位置编码

class LearnedPositionalEncoding(nn.Module):
    """可学习的位置编码"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.position_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        position_embeddings = self.position_embedding(positions)
        return x + position_embeddings

4. RoPE(旋转位置编码)

"""
RoPE = Rotary Position Embedding
现代 LLM(如 LLaMA、ChatGLM)使用的位置编码

核心思想:用旋转矩阵对 Q 和 K 进行变换,
直接融入位置信息,不需要加到词嵌入上!
"""
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    def forward(self, x, seq_len=None):
        # x: (batch, num_heads, seq_len, head_dim)
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat([freqs.sin(), freqs.cos()], dim=-1)
        return emb.cos(), emb.sin()

5. 小结

位置编码三剑客:

1. 绝对位置编码(Sin/Cos):唯一位置信息,位置可泛化
2. 可学习位置编码:端到端学习,效果通常更好
3. RoPE(旋转):现代 LLM 标配,融入 Q/K 旋转

2026 年主流:RoPE(支持更长上下文)
"""

💡 记住:没有位置编码的 Transformer 等价于词袋模型(Bag of Words),完全无法区分词语顺序!


🔗 扩展阅读