CRNN详解:端到端不定长文字识别模型

当你用手机扫描快递单自动填地址、停车场摄像头秒读车牌缴费、PDF转Word提取纯文本时,背后大概率有个高效的不定长文本识别引擎——而CRNN正是这类引擎的开山鼻祖和工业级应用基石之一。


引言

在光学字符识别(OCR)的早期,“先切分单字符再分类”是主流思路,但这种方案存在致命缺陷:

  • 依赖复杂的字符分割算法,无法处理粘连字符、模糊变形字符、自然场景倾斜字符
  • 标注成本极高,需要人工框选每一个字符
  • 无法处理非均匀长度、断句模糊的文本序列

2015年,Baoguang Shi等人提出的CRNN(Convolutional Recurrent Neural Network)彻底打破了这一格局:它通过「CNN特征提取→BiLSTM序列建模→CTC对齐解码」的三段式架构,首次实现了完全端到端的不定长文本序列识别,无需任何字符级切分与标注。


1. CRNN模型概述

1.1 核心三段式逻辑

flowchart LR
    A[输入图像<br/>32×W×1/RGB] --> B[CNN特征提取<br/>高维度→序列化特征图<br/>(1×W_seq×512)]
    B --> C[BiLSTM序列建模<br/>捕获上下文依赖<br/>(W_seq×Batch×Hidden*2)]
    C --> D[线性分类层<br/>(W_seq×Batch×nclass)]
    D --> E[CTC贪婪解码<br/>得到最终文本]

一句话总结:把CNN从「图像分类/检测」的工具,转化为「为序列模型喂视觉时间步」的特征生产者,再用BiLSTM补全字符间的语言/结构关联,最后靠CTC解决输出与标签的长度不匹配问题

1.2 核心工业级优势

  • 端到端训练:只需要「图像→文本」的配对数据
  • 处理任意宽高比:高度固定32,宽度可以无限伸缩(只要输入序列长度≥目标文本长度)
  • 轻量高效:推理速度是Transformer-based模型的3-10倍,适合边缘设备部署
  • 依赖少:不需要词典辅助(有词典能提升,但不是必须)
  • 可解释性强:每个时间步对应图像上的一列像素,便于调试错误

2. PyTorch极简实现

为了让你快速上手,这里提供一个裁剪优化版VGG+双层BiLSTM+标准CTC兼容的PyTorch实现,代码仅200行左右,完全可训练可推理。

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------
# 1. CNN特征提取器:裁剪VGG,专门适配文本识别
# ---------------------------------------------------------
class VGGTextBackbone(nn.Module):
    """
    输入高度必须固定为32,宽度可变;输出高压缩为1,宽压缩为≈W/4±1
    """
    def __init__(self, in_channels=1, out_channels=512):
        super().__init__()
        self.backbone = nn.Sequential(
            # Block 1: 32×W → 16×W/2
            nn.Conv2d(in_channels, 64, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            
            # Block 2: 16×W/2 → 8×W/4
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            
            # Block 3: 8×W/4 → 4×(W/4+2)(宽stride=1,保留序列长度)
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2,2), (2,1), (0,1)),
            
            # Block 4: 4×(W/4+2) → 1×(W/4+1)(高压缩为1)
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2,2), (2,1), (0,0)),
            
            # Block 5: 进一步降维,宽再减1(≈W/4)
            nn.Conv2d(512, out_channels, 2, 1, 0), nn.ReLU(True)
        )

    def forward(self, x):
        return self.backbone(x)

# ---------------------------------------------------------
# 2. 单层双向LSTM:自带降维线性层
# ---------------------------------------------------------
class BiLSTMEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden_dim, bidirectional=True, batch_first=False)
        self.linear = nn.Linear(hidden_dim * 2, out_dim) # 双向输出拼接后降维

    def forward(self, x):
        # x shape: (Time_steps, Batch, in_dim)
        lstm_out, _ = self.lstm(x) # lstm_out: (Time_steps, Batch, hidden*2)
        
        # 时间步+批量 展平给线性层,恢复形状
        T, B, H2 = lstm_out.shape
        return self.linear(lstm_out.reshape(T*B, H2)).reshape(T, B, -1)

# ---------------------------------------------------------
# 3. 完整CRNN模型
# ---------------------------------------------------------
class CRNN(nn.Module):
    def __init__(self, img_h=32, in_channels=1, nclass=27, hidden_dim=256):
        """
        img_h: 固定为32(否则VGG下采样后高度不为1)
        nclass: 必须包含【空白符(0) + 目标字符集(1-N)】
        """
        super().__init__()
        assert img_h % 16 == 0, "img_h必须是16的倍数"
        self.backbone = VGGTextBackbone(in_channels)
        self.rnn1 = BiLSTMEncoder(512, hidden_dim, hidden_dim)
        self.rnn2 = BiLSTMEncoder(hidden_dim, hidden_dim, nclass)

    def forward(self, x):
        # Step 1: CNN提取特征
        conv = self.backbone(x) # (B, 512, 1, W_seq)
        B, C, H, W_seq = conv.shape
        assert H == 1, "CNN输出高度必须为1"
        
        # Step 2: 特征转序列(关键!)
        conv = conv.squeeze(2)    # (B, C, W_seq)
        conv = conv.permute(2, 0, 1) # (W_seq, B, C) → 符合PyTorch RNN输入格式
        
        # Step 3: 序列预测
        rnn_out = self.rnn1(conv)
        return self.rnn2(rnn_out)

# ---------------------------------------------------------
# 4. 模型测试
# ---------------------------------------------------------
if __name__ == "__main__":
    # 假设识别小写英文a-z,加上空白符共27类
    model = CRNN(nclass=27)
    dummy_img = torch.randn(1, 1, 32, 100) # 1张32×100的灰度图
    
    with torch.no_grad():
        output = model(dummy_img)
    
    print(f"输入形状: {dummy_img.shape}")
    print(f"输出形状: {output.shape}") # 应该是 (24, 1, 27),24是时间步长

3. 训练与推理快速指南

3.1 训练(CTC Loss使用细节)

PyTorch内置的nn.CTCLoss完全兼容CRNN输出,但要注意以下参数:

import torch.optim as optim

# 1. 初始化
model = CRNN(nclass=27)
criterion = nn.CTCLoss(blank=0, reduction='mean') # blank索引固定为0
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 2. 假设一次迭代的样本
# images: (Batch, 1, 32, W)
# targets: 是【所有标签的一维拼接】,例如['abc', 'de']→[1,2,3,4,5]
# target_lengths: [3, 2]
# input_lengths: 每个样本的时间步长(所有样本时间步长相同的话用torch.full)
images = torch.randn(2, 1, 32, 100)
targets = torch.tensor([1,2,3,4,5], dtype=torch.long)
target_lengths = torch.tensor([3,2], dtype=torch.long)
input_lengths = torch.full((2,), 24, dtype=torch.long)

# 3. 前向传播+计算损失
model_output = model(images)
log_probs = F.log_softmax(model_output, dim=2) # CTC必须用对数概率
loss = criterion(log_probs, targets, input_lengths, target_lengths)

# 4. 反向传播+更新
optimizer.zero_grad()
loss.backward()
optimizer.step()

3.2 推理(贪婪解码实现)

最简单的解码方式,无需词典,适合快速验证:

def ctc_greedy_decode(output_probs, idx2char, blank_idx=0):
    """
    output_probs: (Time_steps, nclass) → 推理时取单样本的输出
    idx2char: 索引→字符的字典,如{1:'a', 2:'b', ...}
    """
    # 1. 每个时间步取概率最大的索引
    pred_indices = output_probs.argmax(dim=1).cpu().numpy()
    
    # 2. 合并连续重复的非blank,去除所有blank
    decoded = []
    prev_idx = blank_idx
    for idx in pred_indices:
        if idx != blank_idx and idx != prev_idx:
            decoded.append(idx)
        prev_idx = idx
    
    # 3. 转文本
    return ''.join([idx2char[i] for i in decoded])

4. 实践建议

4.1 数据处理

  • 输入图像高度必须固定为32,宽度按图像原始比例缩放,长边不超过256/512(根据显存调整)
  • 灰度图效果通常优于RGB(除非字符颜色与背景有强彩色依赖)
  • 数据增强:随机轻微倾斜(-15°~15°)、随机拉伸(宽0.9-1.1)、添加高斯噪声/模糊、对比度调整,这4种对CRNN提升最大

4.2 模型部署

  • 边缘设备(手机/摄像头):用torch.onnx转ONNX,再用ONNX Runtime-TensorRT/NCNN/TNN加速,推理速度可达100fps+
  • 云端/服务器:直接用PyTorch推理或TensorRT加速即可

总结

CRNN是OCR领域从「传统切分」转向「端到端识别」的里程碑模型,虽然现在Transformer-based模型(如CRNN-Transformer、PARSeq、MASTER)在准确率上占优,但CRNN的轻量、高效、依赖少特点,至今仍是车牌识别、票据识别、文档行识别等标准化场景的首选。

建议先掌握本文的实现代码,再尝试在SynthText/IIIT5K/自己的数据集上训练,最后对比不同模型的效果!


1. CTC Loss的`blank`索引**必须放在字符集的第一位** 2. CNN输出的时间步长**必须≥目标文本的最大长度** 3. 推理前一定要做`log_softmax`?不,贪婪解码直接用`argmax`即可,但训练必须用对数概率

🔗 相关资源