import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------
# 1. CNN特征提取器:裁剪VGG,专门适配文本识别
# ---------------------------------------------------------
class VGGTextBackbone(nn.Module):
    """VGG-style CNN feature extractor tailored for text recognition.

    The input height must be fixed at 32 (width may vary). The output
    feature map has height 1 and width roughly W/4 (+/- 1), so each
    output column becomes one time step for the downstream RNN.
    """

    def __init__(self, in_channels=1, out_channels=512):
        super().__init__()
        layers = []
        # Block 1: 32 x W -> 16 x W/2
        layers += [nn.Conv2d(in_channels, 64, 3, 1, 1), nn.ReLU(True),
                   nn.MaxPool2d(2, 2)]
        # Block 2: 16 x W/2 -> 8 x W/4
        layers += [nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True),
                   nn.MaxPool2d(2, 2)]
        # Block 3: 8 x W/4 -> 4 x (W/4 + 2); width stride 1 preserves
        # the sequence length while the height is halved.
        layers += [nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
                   nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),
                   nn.MaxPool2d((2, 2), (2, 1), (0, 1))]
        # Block 4: 4 x (W/4 + 2) -> 2 x (W/4 + 1)
        layers += [nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
                   nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
                   nn.MaxPool2d((2, 2), (2, 1), (0, 0))]
        # Block 5: the 2x2 conv (no padding) collapses the height from 2
        # to 1 and shrinks the width by 1 (~ W/4).
        layers += [nn.Conv2d(512, out_channels, 2, 1, 0), nn.ReLU(True)]
        self.backbone = nn.Sequential(*layers)

    def forward(self, x):
        """x: (B, in_channels, 32, W) -> (B, out_channels, 1, ~W/4)."""
        return self.backbone(x)
# ---------------------------------------------------------
# 2. 单层双向LSTM:自带降维线性层
# ---------------------------------------------------------
class BiLSTMEncoder(nn.Module):
    """Single bidirectional LSTM layer with a built-in projection.

    The forward/backward hidden states are concatenated (2 * hidden_dim)
    and projected down to ``out_dim`` by a linear layer.

    Args:
        in_dim: feature size of each time step.
        hidden_dim: hidden size of each LSTM direction.
        out_dim: feature size after the projection.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden_dim, bidirectional=True, batch_first=False)
        self.linear = nn.Linear(hidden_dim * 2, out_dim)  # project concatenated directions

    def forward(self, x):
        """x: (T, B, in_dim) -> (T, B, out_dim)."""
        lstm_out, _ = self.lstm(x)  # (T, B, hidden_dim * 2)
        # nn.Linear applies to the last dimension of any N-D tensor, so
        # the original flatten-to-(T*B, H2)-and-reshape round trip is
        # unnecessary; this is numerically identical.
        return self.linear(lstm_out)
# ---------------------------------------------------------
# 3. 完整CRNN模型
# ---------------------------------------------------------
class CRNN(nn.Module):
    """Full CRNN: CNN feature extractor + two stacked bidirectional LSTMs.

    Args:
        img_h: input image height; must be a multiple of 16, otherwise
            the VGG backbone cannot collapse the height down to 1.
        in_channels: number of input image channels (1 for grayscale).
        nclass: number of output classes — must include the CTC blank
            (index 0) plus the target character set (indices 1..N).
        hidden_dim: hidden size per direction of each LSTM layer.

    Raises:
        ValueError: if ``img_h`` is not a multiple of 16, or if the CNN
            output height is not 1 at inference time.
    """

    def __init__(self, img_h=32, in_channels=1, nclass=27, hidden_dim=256):
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently allow invalid heights through.
        if img_h % 16 != 0:
            raise ValueError("img_h must be a multiple of 16")
        self.backbone = VGGTextBackbone(in_channels)
        self.rnn1 = BiLSTMEncoder(512, hidden_dim, hidden_dim)
        self.rnn2 = BiLSTMEncoder(hidden_dim, hidden_dim, nclass)

    def forward(self, x):
        """x: (B, C, img_h, W) -> logits of shape (T, B, nclass)."""
        # Step 1: CNN feature extraction -> (B, 512, 1, W_seq)
        conv = self.backbone(x)
        b, c, h, w_seq = conv.shape
        if h != 1:
            raise ValueError("CNN output height must be 1")
        # Step 2: turn the feature map into a sequence (the key step!)
        conv = conv.squeeze(2)        # (B, C, W_seq)
        conv = conv.permute(2, 0, 1)  # (W_seq, B, C) — PyTorch RNN input layout
        # Step 3: sequence prediction through the stacked BiLSTMs
        rnn_out = self.rnn1(conv)
        return self.rnn2(rnn_out)
# ---------------------------------------------------------
# 4. 模型测试
# ---------------------------------------------------------
if __name__ == "__main__":
# 假设识别小写英文a-z,加上空白符共27类
model = CRNN(nclass=27)
dummy_img = torch.randn(1, 1, 32, 100) # 1张32×100的灰度图
with torch.no_grad():
output = model(dummy_img)
print(f"输入形状: {dummy_img.shape}")
print(f"输出形状: {output.shape}") # 应该是 (24, 1, 27),24是时间步长