循环神经网络 (RNN):处理序列数据的逻辑

📂 所属阶段:第二阶段 — 深度学习与序列模型(进阶篇)
🔗 相关章节:PyTorch 基础 · 长短时记忆网络 LSTM/GRU


1. 为什么需要 RNN?

1.1 传统神经网络的局限

传统 NN(CNN/Dense):每个输入独立处理,无视序列顺序

文本序列:"这部电影 好看" vs "这部电影 难看"
→ 传统 NN 对两个句子输出相同的特征向量
→ 无法区分"好看"和"难看"在语义上的巨大差异

RNN:引入"记忆",让网络记住之前看到的内容
→ "好看"出现在"不错"之后 → 整体情感偏正面
→ "难看"出现在"太差"之后 → 整体情感偏负面

1.2 RNN 的核心思想

RNN = 循环 + 隐藏状态(记忆)

展开图(按时间步展开):
        x₀        x₁        x₂
        ↓         ↓         ↓
      ┌───┐    ┌───┐    ┌───┐
 h₀→ │ RNN │→ h₁→│ RNN │→ h₂→│ RNN │
      └───┘    └───┘    └───┘    └───┘
              ↓         ↓         ↓
              y₀        y₁        y₂

每个时间步:
  h_t = tanh(W_xh · x_t + W_hh · h_{t-1} + b)

  - x_t:当前输入
  - h_{t-1}:上一步的隐藏状态(记忆)
  - h_t:新的隐藏状态
  - y_t:当前输出

2. RNN 的问题

2.1 梯度消失与梯度爆炸

梯度消失问题:
→ 当序列很长时,反向传播的梯度经过多层连乘 → 趋近于 0
→ 模型无法学习长距离依赖

梯度爆炸问题:
→ 梯度经过多层连乘 → 趋近于无穷大
→ 训练不稳定,loss 爆炸

这就是为什么 RNN 难以记住"很久之前"的信息!

2.2 长期依赖问题

# 长期依赖困难
# "我出生在中国……(1000字后)……我会说中文" → RNN 忘记了"中国"
# "我出生在日本……(1000字后)……我会说___" → 很难填出"日语"

3. PyTorch RNN 实现

3.1 基础文本分类

import torch
import torch.nn as nn

class RNNClassifier(nn.Module):
    """基于 RNN 的文本分类器"""
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.RNN(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=False,  # 单向
        )
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids):
        # input_ids: (batch, seq_len)
        embedded = self.embedding(input_ids)  # (batch, seq_len, embed_dim)

        # RNN
        output, hidden = self.rnn(embedded)
        # output: 每个时间步的隐藏状态
        # hidden: 最后一个时间步的隐藏状态(包含所有序列信息)

        # 取最后一个时间步
        last_hidden = output[:, -1, :]
        logits = self.fc(last_hidden)
        return logits

# 使用
model = RNNClassifier(vocab_size=10000)
input_ids = torch.randint(1, 10000, (32, 50))  # batch=32, seq_len=50
logits = model(input_ids)
print(logits.shape)  # torch.Size([32, 2])

3.2 双向 RNN

class BiRNNClassifier(nn.Module):
    """双向 RNN:同时看前向和后向"""
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.RNN(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,  # 关键!双向
            dropout=0.3,
        )
        # hidden_dim * 2 因为双向
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        output, (forward_h, backward_h) = self.rnn(embedded)

        # 拼接双向最后隐藏状态
        # forward_h[-1]: 最后一个前向隐藏状态
        # backward_h[-1]: 最后一个后向隐藏状态
        last_hidden = torch.cat([forward_h[-1], backward_h[-1]], dim=-1)
        logits = self.fc(last_hidden)
        return logits

4. 小结

# RNN 速查

# PyTorch RNN
nn.RNN(input_size, hidden_size, num_layers,
       batch_first=True, bidirectional=True, dropout=0.3)

output, hidden = rnn(input)
# output: (batch, seq, hidden*dirs)
# hidden: (num_layers*dirs, batch, hidden)

# 问题
# 梯度消失 → 长序列记不住 → 解决:LSTM / GRU
# 梯度爆炸 → 训练不稳定 → 解决:梯度裁剪

💡 记住:RNN 本身由于梯度消失/爆炸问题,在实践中很少直接使用。真正起作用的是它的改进版本——LSTM 和 GRU。


🔗 扩展阅读