import torch
import torch.nn as nn
import random
# --------------------------
# Bidirectional LSTM Encoder
# --------------------------
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2):
        super().__init__()
        # Embedding layer: maps token ids to dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Bidirectional LSTM: each input token sees both its left and right context
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, bidirectional=True,
            dropout=0.2 if num_layers > 1 else 0,
        )
        # Project the concatenated bidirectional hidden/cell states down to a
        # single direction so they fit the unidirectional Decoder
        self.hidden_fc = nn.Linear(hidden_dim * 2, hidden_dim)
        self.cell_fc = nn.Linear(hidden_dim * 2, hidden_dim)
    def forward(self, input_ids):
        # input_ids shape: (batch_size, seq_len)
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embed_dim)
        # outputs: per-timestep outputs; (hidden, cell): states at the final timestep
        outputs, (hidden, cell) = self.lstm(embedded)
        # Concatenate the two directions of the last layer
        # (hidden[-2] is the final forward state, hidden[-1] the final backward state)
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=-1)
        cell = torch.cat([cell[-2], cell[-1]], dim=-1)
        # Compress to the dimension the unidirectional Decoder expects
        hidden = self.hidden_fc(hidden).unsqueeze(0)  # (1, batch_size, hidden_dim)
        cell = self.cell_fc(cell).unsqueeze(0)
        return outputs, (hidden, cell)
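# Illustrative Encoder shape check (a sketch; vocab_size=1000, embed_dim=64,
# hidden_dim=128, and the random id batch are arbitrary assumptions):
#   enc = Encoder(vocab_size=1000, embed_dim=64, hidden_dim=128)
#   out, (h, c) = enc(torch.randint(1, 1000, (4, 10)))
#   out: (4, 10, 256)  -> 2 directions * hidden_dim
#   h, c: (1, 4, 128)  -> ready for a 1-layer unidirectional Decoder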
# --------------------------
# Unidirectional LSTM Decoder with a simple concatenated context
# --------------------------
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Input is the current token embedding concatenated with a fixed context vector
        self.lstm = nn.LSTM(
            embed_dim + hidden_dim, hidden_dim, num_layers,
            batch_first=True, dropout=0.2 if num_layers > 1 else 0,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)  # logits over the vocabulary
    def forward(self, input_t, hidden, cell, context):
        # input_t shape: (batch_size, 1), the single current token
        embedded = self.embedding(input_t)  # (batch_size, 1, embed_dim)
        # Concatenate the context vector onto the embedding
        lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=-1)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        logits = self.fc(output.squeeze(1))  # (batch_size, vocab_size)
        return logits, hidden, cell
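# Single decode step sketch (dims follow the Encoder example above; all
# values are arbitrary assumptions, not part of the original module):
#   dec = Decoder(vocab_size=1000, embed_dim=64, hidden_dim=128)
#   logits, h, c = dec(torch.randint(1, 1000, (4, 1)), h, c, h[-1])
#   logits: (4, 1000)  h, c: (1, 4, 128)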
# --------------------------
# Full Seq2Seq model
# --------------------------
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, input_ids, target_ids, teacher_forcing_ratio=0.5):
        batch_size = input_ids.size(0)
        target_len = target_ids.size(1)
        target_vocab_size = self.decoder.fc.out_features
        # Pre-allocate the logits tensor; position 0 (the <START> slot) stays
        # all zeros and should be excluded from the loss
        outputs = torch.zeros(
            batch_size, target_len, target_vocab_size, device=input_ids.device
        )
        # Run the Encoder to get the compressed states and the context
        _, (hidden, cell) = self.encoder(input_ids)
        context = hidden[-1]  # last layer's hidden state as the fixed context
        # First decoding input: the <START> token
        decoder_input = target_ids[:, 0:1]
        # Generate token by token: the step that consumes position t-1 predicts
        # position t, so the loop starts at 1 and outputs[:, t] aligns with target_ids[:, t]
        for t in range(1, target_len):
            logits, hidden, cell = self.decoder(decoder_input, hidden, cell, context)
            outputs[:, t] = logits
            # Teacher forcing: randomly feed either the ground-truth token that
            # was just predicted or the model's own prediction as the next input
            teacher_force = random.random() < teacher_forcing_ratio
            top1_token = logits.argmax(1)
            decoder_input = target_ids[:, t:t+1] if teacher_force else top1_token.unsqueeze(1)
        return outputs
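# Minimal end-to-end smoke test. This is an illustrative sketch: the vocab
# sizes, dimensions, sequence lengths, and random id tensors below are
# assumptions chosen only to exercise the shapes (id 0 is reserved for
# padding per padding_idx=0).
if __name__ == "__main__":
    torch.manual_seed(0)
    SRC_VOCAB, TGT_VOCAB, EMBED, HIDDEN = 1000, 1200, 64, 128
    model = Seq2Seq(
        Encoder(SRC_VOCAB, EMBED, HIDDEN),
        Decoder(TGT_VOCAB, EMBED, HIDDEN),
    )
    src = torch.randint(1, SRC_VOCAB, (4, 10))  # (batch, src_len)
    tgt = torch.randint(1, TGT_VOCAB, (4, 12))  # (batch, tgt_len); tgt[:, 0] plays the <START> role
    logits = model(src, tgt, teacher_forcing_ratio=0.5)
    print(logits.shape)  # torch.Size([4, 12, 1200]); position 0 is all zeros
    # Typical loss, skipping position 0 (the <START> slot) and padding id 0:
    loss = nn.CrossEntropyLoss(ignore_index=0)(
        logits[:, 1:].reshape(-1, TGT_VOCAB), tgt[:, 1:].reshape(-1)
    )
    print(loss.item())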