序列到序列模型 (Seq2Seq):Encoder-Decoder 架构

📂 所属阶段:第二阶段 — 深度学习与序列模型(进阶篇)
🔗 相关章节:长短时记忆网络 LSTM/GRU · 注意力机制


要是你玩过翻译软件、跟AI聊过天,或者看过一张猫的图被生成成文字描述,大概率你已经见过序列到序列(Seq2Seq)模型的身影了!它是深度学习里处理变长输入→变长输出任务的经典架构,也是后来火遍AI圈的Transformer的老祖宗之一,啃透它能帮你搞懂Attention机制为什么“必须横空出世”。


1. Seq2Seq 是什么?

1.1 一句话定义+主流场景

Seq2Seq 的核心逻辑非常直观:

输入任意长度的序列 → 模型“消化”成统一载体 → 输出任意长度的目标序列

它覆盖的应用场景几乎涉及所有“序列转化”类任务,举几个大家耳熟能详的:

├── 机器翻译:中文“你好,NLP爱好者!” → 英文“Hello, NLP lovers!”
├── 文本摘要:2000字科技新闻 → 300字核心要点
├── 对话系统:用户“今天天气怎么样?” → 机器人回复
├── 代码生成:自然语言“写一个Python列表去重的函数” → Python 代码块
├── 语音识别:一段30秒的中文音频 → 对应文字转录
└── 图像描述(变体):一张猫跳栏杆的图 → “一只橘猫正在跳过白色的小栏杆”

1.2 经典 Encoder-Decoder 架构拆解

别看场景多,核心结构永远是两个循环神经网络(RNN/LSTM/GRU,早期多为单向RNN,后来常用双向LSTM) 搭起来的组合拳:

Seq2Seq = 负责“读”的 Encoder + 负责“写”的 Decoder + 连接两者的 Context Vector

我们拿中文→英文翻译(“我 爱 NLP”→“I love NLP”) 举个具象的例子:

  1. Encoder(编码器):把输入的中文词序列逐个喂进去,最后一步输出的隐藏状态(Hidden State)或者“隐藏状态+细胞状态”(如果用LSTM/GRU),会被打包成一个固定长度的 Context Vector(上下文向量)
    • 这个向量是模型对整个输入序列的“压缩记忆”。
  2. Decoder(解码器):把 Context Vector 作为初始状态,再从 <START> 特殊token开始,逐词生成目标英文序列,直到遇到 <END> 特殊token停止。

2. PyTorch 极简 Seq2Seq 实现

光说不练假把式!我们用PyTorch写一个基于双向LSTM的Encoder+简单单向LSTM的Decoder的小模型,附带最常用的两种解码方法。

2.1 完整的基础模型代码

import torch
import torch.nn as nn
import random

# --------------------------
# 双向LSTM Encoder
# --------------------------
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2):
        super().__init__()
        # 词嵌入层:把token id转成向量
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # 双向LSTM:能同时看到输入词的“上文”和“下文”
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, bidirectional=True, dropout=0.2 if num_layers>1 else 0
        )
        # 把双向最后一层的隐藏/细胞状态拼接成单方向,适配单向Decoder
        self.hidden_fc = nn.Linear(hidden_dim * 2, hidden_dim)
        self.cell_fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, input_ids):
        # input_ids shape: (batch_size, seq_len)
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embed_dim)
        
        # outputs: 所有时间步的输出;(hidden, cell): 最后时间步的状态
        outputs, (hidden, cell) = self.lstm(embedded)
        
        # 拼接双向的最后隐藏层(hidden[-2]是正向最后,hidden[-1]是反向最后)
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=-1)
        cell = torch.cat([cell[-2], cell[-1]], dim=-1)
        # 压缩成单向Decoder需要的维度
        hidden = self.hidden_fc(hidden).unsqueeze(0)  # (1, batch_size, hidden_dim)
        cell = self.cell_fc(cell).unsqueeze(0)

        return outputs, (hidden, cell)


# --------------------------
# 带简单拼接Context的单向LSTM Decoder
# --------------------------
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # 输入是当前词嵌入 + 固定Context Vector
        self.lstm = nn.LSTM(
            embed_dim + hidden_dim, hidden_dim, num_layers,
            batch_first=True, dropout=0.2 if num_layers>1 else 0
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)  # 输出每个词的概率logits

    def forward(self, input_t, hidden, cell, context):
        # input_t shape: (batch_size, 1) → 当前单个词
        embedded = self.embedding(input_t)  # (batch_size, 1, embed_dim)
        # 拼接上下文向量
        lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=-1)
        
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        logits = self.fc(output.squeeze(1))  # (batch_size, vocab_size)
        return logits, hidden, cell


# --------------------------
# 完整 Seq2Seq 模型
# --------------------------
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_ids, target_ids, teacher_forcing_ratio=0.5):
        batch_size = input_ids.size(0)
        target_len = target_ids.size(1)
        target_vocab_size = self.decoder.fc.out_features

        # 预分配输出的logits矩阵
        outputs = torch.zeros(batch_size, target_len, target_vocab_size).to(input_ids.device)

        # 先过Encoder拿到压缩状态和Context
        _, (hidden, cell) = self.encoder(input_ids)
        context = hidden[-1]  # 取最后一层的隐藏状态作为固定Context

        # 解码第一步:用<START> token
        decoder_input = target_ids[:, 0:1]

        # 逐词生成
        for t in range(target_len):
            logits, hidden, cell = self.decoder(decoder_input, hidden, cell, context)
            outputs[:, t] = logits

            # Teacher Forcing策略:随机选择用真实标签还是上一步预测结果
            teacher_force = random.random() < teacher_forcing_ratio
            top1_token = logits.argmax(1)
            decoder_input = target_ids[:, t:t+1] if teacher_force else top1_token.unsqueeze(1)

        return outputs

2.2 两种核心解码策略

模型训练好后,如何从输出的logits生成合理的目标序列? 这里介绍两种最常用的方法:

① 贪婪解码(Greedy Decode)

最“简单粗暴”的策略:每一步只选概率最高的词,直到遇到<END>或达到最大长度。

def greedy_decode(model, input_ids, tokenizer, max_len=50, start_token=2, end_token=3):
    model.eval()
    with torch.no_grad():
        # 先过Encoder
        _, (hidden, cell) = model.encoder(input_ids)
        context = hidden[-1]

        # 初始化:只有<START>
        decoder_input = torch.tensor([[start_token]]).to(input_ids.device)
        generated_tokens = []

        for _ in range(max_len):
            logits, hidden, cell = model.decoder(decoder_input, hidden, cell, context)
            # 取概率最高的词
            next_token = logits.argmax(1).item()

            if next_token == end_token:
                break
            generated_tokens.append(next_token)
            decoder_input = torch.tensor([[next_token]]).to(input_ids.device)

        # 用tokenizer把id转成文字
        return tokenizer.decode(generated_tokens)

贪婪解码的缺点是容易陷入局部最优(比如第一步选了概率最高但后面整体不通顺的词)。Beam Search通过维护top-k个“当前最优候选序列” 来缓解这个问题,k称为“beam size”。

def beam_search_decode(model, input_ids, tokenizer, max_len=50,
                        beam_size=5, start_token=2, end_token=3):
    model.eval()
    with torch.no_grad():
        # 先过Encoder
        _, (hidden, cell) = model.encoder(input_ids)
        context = hidden[-1]

        # 初始化beam列表:每个beam记录tokens、总得分、当前隐藏/细胞状态
        beams = [
            {"tokens": [start_token], "score": 0.0, "hidden": hidden, "cell": cell}
        ]

        for _ in range(max_len):
            all_candidates = []
            # 对每个当前beam展开
            for beam in beams:
                # 如果已经遇到<END>,直接保留这个候选
                if beam["tokens"][-1] == end_token:
                    all_candidates.append(beam)
                    continue

                # 预测下一个词
                decoder_input = torch.tensor([[beam["tokens"][-1]]]).to(input_ids.device)
                logits, new_hidden, new_cell = model.decoder(
                    decoder_input, beam["hidden"], beam["cell"], context
                )
                # 转成log概率(避免连乘下溢,用加法更稳定)
                log_probs = torch.log_softmax(logits, dim=-1)
                # 取top-k个候选词
                topk_log_probs, topk_tokens = log_probs.topk(beam_size, dim=-1)

                # 生成k个新候选
                for i in range(beam_size):
                    token = topk_tokens[0, i].item()
                    total_score = beam["score"] + topk_log_probs[0, i].item()
                    all_candidates.append({
                        "tokens": beam["tokens"] + [token],
                        "score": total_score,
                        "hidden": new_hidden,
                        "cell": new_cell,
                    })

            # 按总得分降序排序,取top-k保留为新的beam
            all_candidates.sort(key=lambda x: x["score"], reverse=True)
            beams = all_candidates[:beam_size]

        # 最后选得分最高的,去掉开头的<START>再转文字
        return tokenizer.decode(beams[0]["tokens"][1:])

3. 经典Seq2Seq的致命问题→引出Attention

我们刚才的实现里,用了一个固定长度的Context Vector来压缩整个输入序列——这就是经典Seq2Seq最大的信息瓶颈

比如翻译一篇1000字的中文文章,还是用512维的Context Vector,后面的信息肯定会被前面的“淹没”,模型生成到后半段时已经“忘得差不多了”。

→ 怎么解决?Attention机制!让解码器在生成每个词时,能“主动看”输入序列里最相关的几个词,不再依赖单一的固定Context。这也是我们下一篇(注意力机制)要重点讲的内容。

💡 小总结

  1. 经典Seq2Seq = 双向Encoder压缩 + 单向Decoder逐词生成 + Teacher Forcing训练加速
  2. 解码常用贪婪(快但易局部最优)和Beam Search(慢但更通顺)
  3. 它是理解Transformer的必经之路,但现在纯Seq2Seq已经基本被Transformer取代

🔗 扩展阅读与论文