Sequence-to-Sequence Models (Seq2Seq): The Encoder-Decoder Architecture

📂 Stage: Phase 2 — Deep Learning and Sequence Models (Advanced)
🔗 Related chapters: Long Short-Term Memory Networks (LSTM/GRU) · Attention Mechanisms


1. What Is Seq2Seq?

1.1 Application Scenarios

Seq2Seq = a sequence of arbitrary length in → a sequence of arbitrary length out

Applications:
├── Machine translation: 你好 → Hello
├── Text summarization: long article → short summary
├── Dialogue systems: user question → bot reply
├── Code generation: natural language → Python code
├── Speech recognition: audio → text
└── Image captioning: image → text description

1.2 The Encoder-Decoder Architecture

Seq2Seq = Encoder + Decoder

Encoder:
  "我 爱 NLP" → RNN → Context Vector

Decoder:
  Context Vector → RNN → "I love NLP"

Core idea: the entire input sequence is compressed into a single fixed-length vector (the context)!
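This "compression into one vector" is easiest to see from tensor shapes. A quick sketch with toy dimensions (all sizes here are arbitrary): no matter how long the input sequence is, the LSTM's final hidden state has the same fixed size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

# Two inputs of very different lengths...
short = torch.randn(1, 3, 8)    # sequence length 3
long = torch.randn(1, 50, 8)    # sequence length 50

# ...are both compressed into the same fixed-size hidden state.
_, (h_short, _) = lstm(short)
_, (h_long, _) = lstm(long)

print(h_short.shape)  # torch.Size([1, 1, 16])
print(h_long.shape)   # torch.Size([1, 1, 16])
```

The 50-step sequence gets no more room than the 3-step one, which is exactly the information bottleneck discussed in the summary below.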

2. Implementing Seq2Seq in PyTorch

2.1 Full Implementation

import torch
import torch.nn as nn
import random

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                           batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        outputs, (hidden, cell) = self.lstm(embedded)

        # Concatenate the top layer's forward and backward final states
        # to initialize the decoder state
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=-1)
        cell = torch.cat([cell[-2], cell[-1]], dim=-1)
        hidden = self.fc(hidden).unsqueeze(0)
        cell = self.fc(cell).unsqueeze(0)

        return outputs, (hidden, cell)
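The hidden-state bookkeeping in `Encoder.forward` is easier to follow with a standalone shape check. This sketch mirrors the `num_layers=2, bidirectional=True` setup above, with toy sizes chosen for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(8, 16, num_layers=2, batch_first=True, bidirectional=True)
x = torch.randn(3, 10, 8)  # (batch=3, seq_len=10, input=8)

outputs, (hidden, cell) = lstm(x)

# outputs: (batch, seq_len, hidden * 2) — forward and backward concatenated
print(outputs.shape)  # torch.Size([3, 10, 32])
# hidden: (num_layers * num_directions, batch, hidden);
# hidden[-2] / hidden[-1] are the top layer's forward / backward states
print(hidden.shape)   # torch.Size([4, 3, 16])
merged = torch.cat([hidden[-2], hidden[-1]], dim=-1)
print(merged.shape)   # torch.Size([3, 32])
```

The `fc` layer in the encoder then projects this `hidden_dim * 2` vector back down to `hidden_dim` so it can seed a unidirectional decoder.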


class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim + hidden_dim, hidden_dim,
                           num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_t, hidden, cell, context):
        # input_t: (batch, 1), a single token
        embedded = self.embedding(input_t)
        # Concatenate the context vector with the token embedding
        lstm_input = torch.cat([embedded, context], dim=-1)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        logits = self.fc(output.squeeze(1))
        return logits, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_ids, target_ids, teacher_forcing_ratio=0.5):
        batch_size = input_ids.size(0)
        target_len = target_ids.size(1)
        target_vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(batch_size, target_len, target_vocab_size).to(input_ids.device)

        _, (hidden, cell) = self.encoder(input_ids)
        # Fixed context, reused at every step: (batch, 1, hidden_dim)
        context = hidden[-1].unsqueeze(1)

        # Decode token by token; position 0 holds the <START> token, so only
        # positions 1..target_len-1 are predicted (outputs[:, 0] stays zero
        # and should be masked out of the loss)
        decoder_input = target_ids[:, 0:1]  # <START> token

        for t in range(1, target_len):
            logits, hidden, cell = self.decoder(decoder_input, hidden, cell, context)
            outputs[:, t] = logits

            # Teacher forcing: randomly feed the ground-truth token
            # instead of the model's own prediction
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = logits.argmax(1)
            decoder_input = target_ids[:, t:t+1] if teacher_force else top1.unsqueeze(1)

        return outputs
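The teacher-forcing coin flip inside the decoding loop can be illustrated in isolation. The tensors and token ids below are purely hypothetical stand-ins for a real decoder:

```python
import random
import torch

random.seed(0)
torch.manual_seed(0)

teacher_forcing_ratio = 0.5
target = torch.tensor([[2, 5, 7, 9, 3]])  # hypothetical ids: <START> ... <END>
fake_logits = torch.randn(1, 5, 10)       # stand-in for per-step decoder logits

inputs_used = []
decoder_input = target[:, 0:1]            # start from <START>
for t in range(1, target.size(1)):
    step_logits = fake_logits[:, t]       # pretend the decoder produced these
    if random.random() < teacher_forcing_ratio:
        decoder_input = target[:, t:t+1]                     # ground truth
        inputs_used.append("teacher")
    else:
        decoder_input = step_logits.argmax(1, keepdim=True)  # model's own guess
        inputs_used.append("model")

print(inputs_used)
```

With the ratio at 0.5, roughly half the steps are fed the ground truth; this stabilizes early training while still exposing the model to its own mistakes.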


def greedy_decode(model, input_ids, tokenizer, max_len=50, start_token=2, end_token=3):
    """Greedy decoding: pick the highest-probability token at each step."""
    model.eval()
    with torch.no_grad():
        _, (hidden, cell) = model.encoder(input_ids)
        context = hidden[-1].unsqueeze(1)  # (batch, 1, hidden_dim)

        decoder_input = torch.tensor([[start_token]]).to(input_ids.device)
        generated = []

        for _ in range(max_len):
            logits, hidden, cell = model.decoder(decoder_input, hidden, cell, context)
            next_token = logits.argmax(1).item()

            if next_token == end_token:
                break
            generated.append(next_token)
            decoder_input = torch.tensor([[next_token]]).to(input_ids.device)

        return tokenizer.decode(generated)


def beam_search_decode(model, input_ids, tokenizer, max_len=50,
                        beam_size=5, start_token=2, end_token=3):
    """Beam Search:维护多个候选"""
    model.eval()
    with torch.no_grad():
        _, (hidden, cell) = model.encoder(input_ids)
        context = hidden[-1].unsqueeze(1)  # (batch, 1, hidden_dim)

        # Start with a single hypothesis containing only <START>
        beams = [{"tokens": [start_token], "score": 0.0, "hidden": hidden, "cell": cell}]

        for _ in range(max_len):
            all_candidates = []
            for beam in beams:
                if beam["tokens"][-1] == end_token:
                    all_candidates.append(beam)
                    continue

                decoder_input = torch.tensor([[beam["tokens"][-1]]]).to(input_ids.device)
                logits, new_hidden, new_cell = model.decoder(
                    decoder_input, beam["hidden"], beam["cell"], context
                )
                log_probs = torch.log_softmax(logits, dim=-1)
                topk_probs, topk_tokens = log_probs.topk(beam_size, dim=-1)

                for i in range(beam_size):
                    token = topk_tokens[0, i].item()
                    prob = topk_probs[0, i].item()
                    candidate = {
                        "tokens": beam["tokens"] + [token],
                        "score": beam["score"] + prob,
                        "hidden": new_hidden,
                        "cell": new_cell,
                    }
                    all_candidates.append(candidate)

            # Keep the top-scoring beams
            all_candidates.sort(key=lambda x: x["score"], reverse=True)
            beams = all_candidates[:beam_size]

            # Stop early once every surviving beam has emitted <END>
            if all(b["tokens"][-1] == end_token for b in beams):
                break

        best = beams[0]["tokens"][1:]      # drop <START>
        if best and best[-1] == end_token:
            best = best[:-1]               # drop <END>
        return tokenizer.decode(best)
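One known weakness of ranking beams by summed log-probabilities, not addressed in the implementation above, is that every extra token adds a negative term, so beam search systematically favors short outputs. A common remedy is length normalization; a minimal sketch (the candidate scores and the `alpha` value are illustrative, not taken from the model above):

```python
# Hypothetical finished hypotheses; "score" is a sum of log-probabilities,
# so longer sequences accumulate more negative terms.
candidates = [
    {"tokens": [2, 7, 3], "score": -1.8},
    {"tokens": [2, 7, 9, 4, 3], "score": -2.2},
]

alpha = 0.7  # length-penalty exponent; values around 0.6-0.7 are common

def normalized(c):
    # Divide by length^alpha so long hypotheses are not unfairly penalized
    return c["score"] / (len(c["tokens"]) ** alpha)

best_raw = max(candidates, key=lambda c: c["score"])
best_norm = max(candidates, key=normalized)
print(len(best_raw["tokens"]), len(best_norm["tokens"]))  # 3 5
```

Raw scores pick the short hypothesis; the normalized score prefers the longer one. Dividing by `len(tokens) ** alpha` (rather than raw length) lets `alpha` interpolate between no penalty (`alpha=0`) and full per-token averaging (`alpha=1`).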

3. Summary

Seq2Seq = Encoder + Decoder + Context Vector

Problem: squeezing an arbitrarily long sequence into one fixed-size vector creates an information bottleneck!

→ Solution: attention! Let the decoder "see" every step of the input
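As a preview of the next chapter: dot-product attention replaces the single context vector with a per-step weighted average over all encoder outputs. A minimal sketch with toy shapes (the dimensions are arbitrary):

```python
import torch

torch.manual_seed(0)
encoder_outputs = torch.randn(1, 4, 8)  # (batch, src_len, hidden): one vector per input step
decoder_hidden = torch.randn(1, 8)      # current decoder state

# Score every encoder step against the decoder state (dot product),
# normalize to attention weights, then take the weighted average.
scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1)).squeeze(-1)  # (1, src_len)
weights = torch.softmax(scores, dim=-1)                                        # sums to 1
context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)          # (1, hidden)

print(weights.shape, context.shape)
```

Because `weights` is recomputed at every decoding step, the context vector changes step by step instead of being fixed once by the encoder, which is what removes the bottleneck.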

💡 Remember: RNN-based Seq2Seq was the mainstream approach before the Transformer, which has since largely replaced it. Understanding Seq2Seq is the foundation for understanding the attention mechanism.


🔗 Further Reading