目录
现在刷爆眼球的GPT-4o、Gemini、Claude,底层核心架构全是2017年Google论文《Attention is All You Need》里提出的Transformer。
它直接抛弃了传统NLP依赖的RNN/LSTM/GRU循环结构,以及CNN的局部感受野,完全靠注意力机制搞定一切——这在深度学习历史上是一次颠覆性的范式转变,不仅解锁了训练超大模型的可能,还让模型能“一眼”记住超长文本的所有细节。
核心结构拆解
Transformer由对称的N个编码器层 + N个解码器层组成(论文里N=6),整体流程如下:
flowchart LR
subgraph Input [输入部分]
I[输入序列] --> W[词嵌入]
W --> P[位置编码]
P --> IW[加权嵌入]
end
subgraph Encoder [编码器N层]
IW --> MHA[多头自注意力]
MHA --> ADD1[残差+层归一化]
ADD1 --> FFN[前馈网络]
FFN --> ADD2[残差+层归一化]
ADD2 --> EO[编码器输出]
end
subgraph Decoder [解码器N层]
TI[目标序列前缀] --> TW[词嵌入]
TW --> TP[位置编码]
TP --> TWI[加权嵌入]
TWI --> MMHA[掩码多头自注意力]
MMHA --> TADD1[残差+层归一化]
TADD1 --> CA[交叉注意力]
CA --> TADD2[残差+层归一化]
TADD2 --> TFFN[前馈网络]
TFFN --> TADD3[残差+层归一化]
TADD3 --> DO[解码器输出]
end
EO -.->|K/V| CA
DO --> F[最终线性层+Softmax]
F --> O[输出概率分布]
相比传统模型的四大优势
- ✅ 完美并行化:不像RNN必须等前一个词处理完,Transformer能同时处理所有位置,训练速度提升几十倍
- ✅ 轻松捕获长依赖:注意力机制直接给任意两个词计算关联,不用像RNN那样“层层传话”
- ✅ 自带可解释性:输出的注意力权重可以可视化,清楚看到模型在“看”哪些词
- ✅ 超简单的扩展框架:堆层数、加维度、扩参数就行,从12层的BERT-base到上万层的GPT-4o都适用
自注意力机制详解
自注意力是Transformer的心脏,它能让序列里的每个位置,都“按需分配注意力”给其他所有位置。
直观理解:用查询-键-值买东西
想象你在超市找零食:
- 你(Query):当前想知道的信息——“有没有咸口薯片?”
- 货架标签(Key):其他零食用来匹配你的需求的特征——“甜口饼干”“咸口薯片”“无糖可乐”
- 零食本身(Value):其他零食的实际内容——“乐事黄瓜味”“奥利奥原味”…
- 最终选择(加权求和):根据标签匹配度(注意力权重),拿最匹配的零食
这种机制让模型在处理某个词时,能自主决定关注上下文中的哪些词。
PyTorch极简实现
import torch
import torch.nn as nn
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
缩放点积注意力的核心实现
Q/K/V: (batch_size, seq_len, d_k/d_v)
"""
d_k = Q.size(-1)
# 1. 计算Q和K的相似度(注意力分数)
# 除以sqrt(d_k)是为了防止分数太大导致softmax梯度消失
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# 2. 应用掩码(比如解码器里不能看未来的词)
if mask is not None:
scores.masked_fill_(mask == 0, -1e9)
# 3. 用softmax把分数转成0-1的权重,总和为1
attention_weights = torch.softmax(scores, dim=-1)
# 4. 加权求和V得到最终输出
output = torch.matmul(attention_weights, V)
return output, attention_weights
代码中除以根号下 d_k 的操作,也叫缩放操作,能避免点积结果过大引起的梯度消失问题。
多头注意力机制
单头注意力只能关注一种“语义模式”,比如“找主语和谓语的关系”;多头注意力就是同时开多个“语义雷达”,一个找主谓、一个找指代、一个找情感词,最后把结果拼起来,模型能力直接翻倍。
实现思路
多头注意力先把输入特征拆成若干个子空间,每个子空间独立计算注意力,最后把各头的结果拼接后再做一个线性变换。
PyTorch实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 四个线性变换:Q/K/V的生成,以及多头结果的合并
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x, batch_size):
"""把大的向量拆成多个小的子向量,并行计算"""
x = x.view(batch_size, -1, self.num_heads, self.d_k)
return x.transpose(1, 2) # 变成(batch_size, num_heads, seq_len, d_k)
def combine_heads(self, x, batch_size):
"""把多个子向量拼回原来的形状"""
x = x.transpose(1, 2).contiguous()
return x.view(batch_size, -1, self.d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 1. 线性变换生成初始Q/K/V
Q = self.W_q(Q)
K = self.W_k(K)
V = self.W_v(V)
# 2. 拆成多头
Q = self.split_heads(Q, batch_size)
K = self.split_heads(K, batch_size)
V = self.split_heads(V, batch_size)
# 3. 计算缩放点积注意力
scaled_attn, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 4. 合并多头
concat_attn = self.combine_heads(scaled_attn, batch_size)
# 5. 最终线性变换
output = self.W_o(concat_attn)
return output, attn_weights
位置编码(Positional Encoding)
Transformer没有循环结构,它“看不见词的顺序”——比如“我打你”和“你打我”在它眼里是一样的。所以我们需要显式给每个词加一个“位置标签”,也就是位置编码。
论文里的固定位置编码
论文用了正弦/余弦函数来生成固定的位置编码,这种方案有两大好处:
- 不需要额外训练参数,完全由函数生成
- 能自然泛化到比训练时更长的序列长度
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 创建一个足够大的位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
# 偶数维度用sin,奇数维度用cos
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 注册为buffer,不会被优化器更新
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
"""把位置编码加到词嵌入上"""
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
编码器(Encoder)架构
编码器负责理解输入序列的上下文信息,把每个词转换成“包含全序列语义的向量”。
单个编码器层
每个编码器层由两个核心模块组成:多头自注意力 + 前馈网络(FFN),每个模块后面都接着残差连接 + 层归一化(这些细节后面单独讲)。
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.mha = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 残差+层归一化的顺序是:LayerNorm(x + SubLayer(x))
# 先过自注意力
attn_out, _ = self.mha(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_out))
# 再过前馈网络
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
return x
完整编码器
完整编码器先经过词嵌入 + 位置编码,然后堆叠N个相同的编码器层。
class Encoder(nn.Module):
def __init__(self, num_layers, d_model, num_heads, d_ff, vocab_size, max_len, dropout=0.1):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.enc_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
def forward(self, x, mask=None):
# 论文里提到要把词嵌入乘以sqrt(d_model),防止位置编码的影响被淹没
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.enc_layers:
x = layer(x, mask)
return x
解码器(Decoder)架构
解码器负责自回归生成输出序列——每生成一个词,就把这个词加回输入,再生成下一个。
单个解码器层
解码器层比编码器层多了一个交叉注意力模块,整体由三个子层组成:
- 掩码多头自注意力:带掩码,防止解码器看到未来的词(保证自回归生成)
- 交叉注意力:Q来自解码器,K和V来自编码器的输出,让解码器“参考”输入序列的信息
- 前馈网络:和编码器完全一样
每个子层同样使用残差连接和层归一化,代码结构与编码器层相似,这里不再重复列出。
核心差异在于自注意力部分的掩码是一个上三角矩阵,确保位置 i 只能看到位置 i 及之前的词。
残差连接与层归一化
这两个是让深度Transformer能训练出来的关键!没有它们,堆个3层可能就梯度消失或爆炸了。
残差连接
思路非常简单:把子层的输入直接加到子层的输出上,即 输入 + 子层输出。这样即使子层学习效果不佳,也能直接把输入信息传递下去,避免信息丢失,同时大大缓解梯度消失问题。
层归一化
与批量归一化(BN)不同,层归一化(LN)是对每个样本的特征维度进行归一化。NLP任务中序列长度经常不一致,BN在这种场景下很不稳定,而LN完全不受影响,因此成了Transformer的标配。
两者的结合让Transformer可以轻松堆叠数十甚至上百层。
将编码器和解码器组装起来,再加上最后的线性层和softmax,就得到了完整的Transformer模型。由于篇幅限制,这里只给出核心拼接思路:编码器处理源序列,解码器以目标序列的前缀和编码器输出作为输入,逐步预测下一个词,最终输出概率分布。
如果你想直接运行完整代码,可以参考经典的 The Annotated Transformer 项目,或使用PyTorch内置的 nn.Transformer 模块快速验证。
实际应用与变体
现在主流的大模型都是Transformer的变体,主要分为三类:
Transformer是现代NLP的基石,但初学者不用一开始就手写完整模型——可以先用Hugging Face Transformers库调预训练模型,跑通几个小项目后再回来啃核心代码,效率会高很多!
总结
Transformer的成功在于它的简洁性和通用性:
- 用自注意力解决了长依赖和并行化
- 用残差+层归一化解决了深度训练问题
- 架构模块化,能轻松扩展到超大模型
理解了Transformer,你就推开了现代大模型的大门!