#位置编码 (Positional Encoding):给没有顺序的矩阵注入"位置感"
📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 相关章节:多头注意力 (Multi-Head Attention) · Transformer 完整架构
#1. 为什么需要位置编码?
#1.1 Self-Attention 的缺陷
Self-Attention:完全并行、无视位置
输入:"I love NLP" 和 "NLP love I"
→ 变换后的 Attention 结果完全相同!
→ 模型无法区分词语顺序!
"我 打 你" vs "你 打 我"
→ 词序颠倒,意思相反,但 Attention 输出相同!#1.2 解决方案
给每个位置一个唯一的"位置信号" → 加到词嵌入上
词嵌入:捕捉语义信息
位置编码:捕捉位置信息
最终表示 = 词嵌入 + 位置编码#2. 正弦/余弦位置编码
#2.1 原始 Transformer 公式
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# 偶数维度用 sin,奇数维度用 cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# (max_len, d_model) → (1, max_len, d_model)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x):
# x: (batch, seq_len, d_model)
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)#2.2 为什么用正弦余弦?
正弦余弦位置编码的特性:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
优势:
1. 任意两个位置的差,可以用线性变换表示
→ PE(pos+k) 可以用 PE(pos) 的线性组合表示!
→ 模型可以学习相对位置关系!
2. 可以泛化到训练时未见过的序列长度
3. 计算高效,无需学习参数#3. 可学习位置编码
class LearnedPositionalEncoding(nn.Module):
"""可学习的位置编码"""
def __init__(self, d_model, max_len=5000):
super().__init__()
self.position_embedding = nn.Embedding(max_len, d_model)
def forward(self, x):
# x: (batch, seq_len, d_model)
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
position_embeddings = self.position_embedding(positions)
return x + position_embeddings#4. RoPE(旋转位置编码)
"""
RoPE = Rotary Position Embedding
现代 LLM(如 LLaMA、ChatGLM)使用的位置编码
核心思想:用旋转矩阵对 Q 和 K 进行变换,
直接融入位置信息,不需要加到词嵌入上!
"""
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
def forward(self, x, seq_len=None):
# x: (batch, num_heads, seq_len, head_dim)
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat([freqs.sin(), freqs.cos()], dim=-1)
return emb.cos(), emb.sin()#5. 小结
位置编码三剑客:
1. 绝对位置编码(Sin/Cos):唯一位置信息,位置可泛化
2. 可学习位置编码:端到端学习,效果通常更好
3. RoPE(旋转):现代 LLM 标配,融入 Q/K 旋转
2026 年主流:RoPE(支持更长上下文)
"""💡 记住:没有位置编码的 Transformer 等价于词袋模型(Bag of Words),完全无法区分词语顺序!
🔗 扩展阅读

