The Complete Transformer Architecture: Hand-Writing a Basic Transformer

📂 Stage: Stage 3 of the series (The Transformer Revolution, core chapters)
🔗 Related chapters: Positional Encoding · The BERT Family in Depth


1. The Overall Transformer Architecture

Transformer = Encoder × N + Decoder × N

Encoder (understands the input):
  Input → token embedding + positional encoding
        → Multi-Head Self-Attention
        → Feed-Forward Network
        → output: contextual representations

Decoder (generates the output):
  Output → token embedding + positional encoding
         → Masked Multi-Head Self-Attention (cannot see the future)
         → Cross-Attention (attends to the Encoder output)
         → Feed-Forward Network
         → output: probabilities for the next token
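The "cannot see the future" constraint in the masked self-attention step above is usually implemented as an additive attention mask. A minimal sketch (mask conventions vary by library; PyTorch's nn.MultiheadAttention expects -inf at disallowed positions):

```python
import torch

# Causal (look-ahead) mask for a sequence of length 5:
# position i may attend to positions 0..i; future positions get -inf.
seq_len = 5
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)
print(causal_mask)
# Row i is 0.0 up to column i and -inf afterwards; after softmax,
# the -inf entries become zero attention weight.
```

Passing this as `attn_mask` to the decoder's self-attention is what keeps generation autoregressive.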

2. A Complete Encoder Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-Attention + residual + LayerNorm
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-Forward + residual + LayerNorm
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8,
                 num_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, input_ids, mask=None):
        # Scale embeddings by sqrt(d_model), as in the original paper
        # (the original code read x.size(-1) before x was assigned)
        x = self.embedding(input_ids) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x
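A quick way to sanity-check the shapes this encoder should produce is to run the same configuration through PyTorch's built-in nn.TransformerEncoder. A sketch with made-up sizes; the hand-written TransformerEncoder above should yield the same [batch, seq_len, d_model] output:

```python
import torch
import torch.nn as nn

batch, seq_len, d_model = 2, 10, 512
layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=8, dim_feedforward=2048,
    dropout=0.1, batch_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=6)

x = torch.randn(batch, seq_len, d_model)  # already-embedded input
out = encoder(x)
print(out.shape)  # torch.Size([2, 10, 512])
```

The built-in layer differs in small details (it wires attention and FFN internally), but the input/output contract is identical.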

3. Residual Connections and LayerNorm

"""
残差连接(Residual Connection):
  输出 = 输入 + 子层(输入)
  → 梯度直接回传,缓解梯度消失
  → 允许训练更深的网络

LayerNorm vs BatchNorm:
  LayerNorm:每个样本独立归一化(NLP 常用)
  BatchNorm:每个特征跨样本归一化(CV 常用)
"""
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads=8, dropout=0.1):
        super().__init__()
        # Use PyTorch's built-in attention so this block is self-contained
        # (the original referenced an undefined MultiHeadAttention class)
        self.attention = nn.MultiheadAttention(d_model, num_heads,
                                               dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Residual connection around the attention sublayer
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm(x + self.dropout(attn_out))
        return x
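The LayerNorm-vs-BatchNorm distinction in the note above is easy to verify numerically. A minimal sketch: LayerNorm standardizes each sample over its own feature dimension, so every row ends up with roughly zero mean and unit variance regardless of the rest of the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8) * 3 + 5  # batch of 4 samples, 8 features each

ln = nn.LayerNorm(8)
y = ln(x)

# Each sample (row) is normalized independently over its 8 features
row_means = y.mean(dim=-1)
row_stds = y.std(dim=-1, unbiased=False)
print(row_means)  # all approximately 0
print(row_stds)   # all approximately 1
```

BatchNorm (nn.BatchNorm1d) would instead give each *column* zero mean across the batch, which is why it behaves badly with variable-length token sequences.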

4. The Full Transformer Model

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
                 num_layers=6, d_ff=1024, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers, d_ff, dropout
        )
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, num_classes),
        )

    def forward(self, input_ids, mask=None):
        encoder_output = self.encoder(input_ids, mask)
        # Pool: take the first ([CLS]-style) token, or mean-pool
        pooled = encoder_output[:, 0]  # [CLS]
        return self.classifier(pooled)
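As an alternative to taking the [CLS] position, mean pooling works better when sequences are padded, provided padding tokens are excluded from the average. A self-contained sketch; `pad_mask` here is a hypothetical boolean mask with True at real tokens:

```python
import torch

batch, seq_len, d_model = 2, 5, 8
encoder_output = torch.randn(batch, seq_len, d_model)
# True = real token, False = padding (hypothetical example mask)
pad_mask = torch.tensor([
    [True, True, True, False, False],
    [True, True, True, True,  True ],
])

mask = pad_mask.unsqueeze(-1).float()        # [batch, seq_len, 1]
summed = (encoder_output * mask).sum(dim=1)  # zero out padding, then sum
counts = mask.sum(dim=1).clamp(min=1e-9)     # number of real tokens per sample
pooled = summed / counts                     # [batch, d_model]
print(pooled.shape)  # torch.Size([2, 8])
```

Dividing by the per-sample token count (rather than seq_len) is the key detail: otherwise heavily-padded sequences get artificially shrunken representations.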

5. Summary

The core components of the Transformer:

1. Token embedding + positional encoding → injects meaning and position
2. Multi-Head Self-Attention → captures relationships between tokens
3. Feed-Forward Network → position-wise nonlinear transformation
4. Residual connections + LayerNorm → stabilize training
5. Stacking N layers → progressively extracts deeper semantics

This is the foundation of every modern large language model!

💡 Remember: Self-Attention is the heart of the Transformer; everything else is built around it. The Encoder is for understanding (BERT), the Decoder is for generation (GPT), and the full encoder-decoder architecture is for translation (the original Transformer).


🔗 Further Reading