长短时记忆网络 (LSTM/GRU):解决梯度消失,捕捉长距离依赖

📂 所属阶段:第二阶段 — 深度学习与序列模型(进阶篇)
🎯 前置知识:循环神经网络 (RNN) 基础
🔗 相关章节:循环神经网络 (RNN) · 序列到序列模型 (Seq2Seq)


1. LSTM 核心思想:给RNN装个“信息保险箱”

1.1 痛点:普通RNN的“短视症”

当你让一个普通RNN去读一篇长文章时,它往往会很快“忘记”开头的内容。比如分析一条电影评论——“《开端》前20分钟有点拖沓,但后面全程高能,最后10集根本停不下来”,普通RNN很可能只记住了结尾的“停不下来”,开头“前20分钟有点拖沓”这个负面信号在反向传播时几乎消失殆尽,这就是梯度消失带来的典型问题。

LSTM的设计目标正是解决这种“健忘症”。它引入了一个贯穿整个序列的细胞状态(Cell State),你可以把它想象成一条“信息传送带”。传送带上可以稳定承载长期记忆,再配合三个可学习的“门”,来决定:

  • 哪些旧信息该从传送带上遗忘
  • 哪些新信息该被写入传送带?
  • 在最终输出时,传送带上的哪些内容该被拿出来用

这三个门就像是数据流的交通信号灯,让模型能够极其细腻地控制信息的流动。

1.2 拆解LSTM的计算流程

为方便理解,我们以情感分析任务为例,逐步跟踪一条评论的处理过程:“这部电影开场有点闷,但结局太治愈太戳泪点了!”

第一步:遗忘门——清理历史记忆

遗忘门要决定“细胞状态里哪些旧信息应该扔掉”。比如读到“但结局”这个词时,模型需要意识到前文“有点闷”的权重应该被降低甚至删除。
具体做法:将上一时刻的隐藏状态 h_prev(上一步的临时记忆)和当前输入 x_now(词向量)拼在一起,经过一组可学习的权重处理,再送给一个数值范围在 [0,1] 之间的“开关函数”。

  • 输出接近 1 表示“完全保留”;
  • 输出接近 0 表示“可以忘了”;
  • 中间值表示“部分保留”。

第二步:输入门 + 候选状态——准备新信息

这一阶段决定“往细胞状态里加入什么新知识”,由两个配件协同运作:

  1. 输入选择门:依然拼接 h_prevx_now,用开关函数选出哪些新信息值得记忆。
  2. 候选状态:同样拼接 h_prevx_now,但改用数值范围在 [-1,1] 的激活函数,生成一个“新内容草稿”。
  3. 两者相乘——只把被输入门“点亮”的内容真正写入传送带。

第三步:更新细胞状态——刷新传送带

这是LSTM最核心的计算:

  1. 遗忘门的输出乘以旧的细胞状态(上一时刻传送带上的内容);
  2. 再加上输入门的写入结果
  3. 得到更新后的细胞状态。

这样一来,不重要旧信息被遗忘,新鲜重要信息被写入,传送带始终携带当前最关键的全局记忆。

第四步:输出门 + 隐藏状态——决定向下一层输出什么

最后,模型决定“从传送带里挑选哪些信息用于生成当前的输出(隐藏状态)”:

  1. 输出选择门:同样拼接 h_prevx_now,通过开关函数选出传送带中该暴露的部分。
  2. 归一化传送带内容:将细胞状态通过 [-1,1] 激活函数压缩一下,避免数值过大。
  3. 两者相乘,得到当前时刻的隐藏状态——它既包含此时此刻最重要信息,也带有长距离依赖的内容,会传给下一时刻或后续的全连接分类层。

1.3 PyTorch LSTM 实战:双向情感分类器

以下实现了一个完整的双向LSTM文本分类模型。双向LSTM可以同时从左到右和从右到左扫描序列,对情感分析这类需要全局理解的任务效果更好。

import torch
import torch.nn as nn

class BiLSTMTextClassifier(nn.Module):
    """
    双向LSTM文本分类器
    适用于情感分析、新闻分类等短文本任务
    """
    def __init__(
        self,
        vocab_size: int,    # 词表大小
        embed_dim: int = 256,  # 词嵌入维度
        hidden_dim: int = 256, # LSTM隐藏层维度
        num_layers: int = 2,   # LSTM层数
        dropout: float = 0.3,  # dropout比例,防止过拟合
        num_classes: int = 2    # 分类类别数(二分类:积极/消极)
    ):
        super().__init__()
        
        # 1. 词嵌入层:把词ID转换成低维稠密向量
        self.embedding = nn.Embedding(
            vocab_size, embed_dim, padding_idx=0  # padding_idx=0:忽略词表中的填充词
        )
        
        # 2. 双向LSTM层
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,   # 输入输出的第一维度是batch_size(更符合习惯)
            bidirectional=True, # 双向:前向看前文,后向看后文
            dropout=dropout if num_layers > 1 else 0  # 只有多层LSTM才加层间dropout
        )
        
        # 3. 全连接分类头:双向拼接后的维度是 hidden_dim*2
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        前向传播
        input_ids: (batch_size, seq_len) → 每个样本是词ID序列
        logits: (batch_size, num_classes) → 每个样本对应类别的未归一化分数
        """
        # 词嵌入:(B, L) → (B, L, E)
        embedded = self.embedding(input_ids)
        
        # LSTM计算
        # output: (B, L, 2H) → 每个时刻的双向隐藏状态拼接
        # (h_n, c_n): 最后时刻的隐藏状态和细胞状态,(2*num_layers, B, H)
        output, (h_n, _) = self.lstm(embedded)
        
        # 取最后一层的双向隐藏状态拼接 → (B, 2H)
        last_layer_forward = h_n[-2]  # 前向最后一层的最后一个隐藏状态
        last_layer_backward = h_n[-1] # 后向最后一层的最后一个隐藏状态
        final_hidden = torch.cat([last_layer_forward, last_layer_backward], dim=-1)
        
        # 分类头计算
        logits = self.classifier(final_hidden)
        return logits

2. GRU:LSTM的“轻量精简版”

2.1 GRU的改进思路

2014年,Cho等人提出了门控循环单元(GRU),用更简洁的结构实现了与LSTM类似的效果。GRU把LSTM的三个门合并成两个,并且去掉了独立的细胞状态——它用一种巧妙的方式把“长期记忆”和“短期临时记忆”融合到统一的隐藏状态中。

特性LSTMGRU
门控数量遗忘门、输入门、输出门重置门、更新门(两个)
状态结构细胞状态 + 隐藏状态单一隐藏状态
参数量较大少约30%
表现与效率复杂任务上更稳定中小型任务效果相近,训练/推理更快

2.2 PyTorch GRU 实战:同样任务,更轻量的选择

将上面的LSTM模型替换成GRU非常简单,只需要把 nn.LSTM 换成 nn.GRU,同时注意GRU不返回细胞状态 c_n 即可。

import torch
import torch.nn as nn

class BiGRUTextClassifier(nn.Module):
    """
    双向GRU文本分类器
    轻量高效,适合快速原型验证或简单任务
    """
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 256,
        hidden_dim: int = 256,
        num_layers: int = 2,
        dropout: float = 0.3,
        num_classes: int = 2
    ):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        # 替换成nn.GRU
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        # 分类头不变
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedded = self.embedding(input_ids)
        # GRU只返回output和h_n,没有c_n
        _, h_n = self.gru(embedded)
        final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        logits = self.classifier(final_hidden)
        return logits

3. 实战片段:快速跑通情感分类训练

3.1 训练与验证单轮函数

为了让模型训练更稳定,我们通常还会使用梯度裁剪来防止梯度爆炸,并同步计算准确率。

import torch
from torch.utils.data import DataLoader

def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    clip_max_norm: float = 1.0  # 梯度裁剪的最大范数
) -> tuple[float, float]:
    """
    训练单轮模型
    返回:平均损失、平均准确率
    """
    model.train()
    total_loss = 0.0
    correct_preds = 0
    total_preds = 0

    for batch in dataloader:
        # 把数据移到GPU/CPU
        input_ids = batch["input_ids"].to(device)
        labels = batch["label"].to(device)

        # 前向传播
        optimizer.zero_grad()
        logits = model(input_ids)
        loss = criterion(logits, labels)

        # 反向传播 + 梯度裁剪 + 参数更新
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        # 统计指标
        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct_preds += (preds == labels).sum().item()
        total_preds += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    avg_acc = correct_preds / total_preds
    return avg_loss, avg_acc

3.2 极简推理函数

推理时只需要加载训练好的模型,关闭梯度计算并调用 softmax 即可得到每个类别的概率。

import torch

def predict_sentiment(
    model: nn.Module,
    tokenizer,  # 假设已初始化好的Tokenizer
    text: str,
    device: torch.device
) -> dict[str, float]:
    """
    预测单条文本的情感
    返回:积极、消极的概率字典
    """
    model.eval()
    with torch.no_grad():  # 推理时不需要计算梯度
        # 分词+转ID+补填充(这里简化,实际用tokenizer的__call__更方便)
        tokens = tokenizer.tokenize(text)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        # 加batch维度 → (1, seq_len)
        input_ids = torch.tensor([ids]).to(device)
        
        # 前向传播
        logits = model(input_ids)
        # 用softmax把未归一化分数转成概率
        probs = torch.softmax(logits, dim=-1).squeeze(0)  # 去掉batch维度
    
    return {
        "positive": round(probs[1].item(), 4),
        "negative": round(probs[0].item(), 4)
    }

4. 2026年的选择建议

4.1 LSTM vs GRU 简单对比

维度LSTMGRU
参数量较多少约30%
训练/推理速度相对较慢更快
长距离记忆能力理论上略强中等/简单任务足够
历史典型应用机器翻译、语音识别情感分析、文本分类

4.2 2026年的实际应用情况

必须诚实地说一句:虽然LSTM/GRU是每个深度学习从业者入门序列建模的“必修课”,但在当下的NLP和语音领域,基于Transformer架构的预训练模型(BERT、GPT、Whisper等)已经基本占据了主流地位。这些模型通过海量无监督预训练获得了强大的通用表示,配合精细的下游微调,性能远超从零训练的LSTM/GRU。而且,随着A100、H100等硬件对自注意力机制的深度优化,Transformer的训练效率甚至可以不逊色于多层LSTM。

4.3 什么时候还会用LSTM/GRU?

尽管如此,LSTM和GRU在以下场景中依然活跃:

  1. 边缘设备与低延迟场景:参数规模小,推理速度快,对硬件要求低。
  2. 顺序性极强的小型任务:比如某些传感器时序异常检测、小众语言的低资源词性标注等。
  3. 学术基线对比:做研究时,LSTM/GRU是最经典的对比模型之一,衡量新方法有效性的重要参照。

🔗 扩展阅读