端到端不定长文字识别:CRNN 模型详解与 PyTorch 实现

1. 前言

在 OCR(光学字符识别)领域,识别图像中的文字序列一直是一个核心挑战。传统的 OCR 方法通常分为两个阶段:先检测单个字符,然后再进行分类。这种方法需要对每个字符进行精细的切分,不仅繁琐,而且对于粘连字符或模糊文本的识别效果很差。

CRNN (Convolutional Recurrent Neural Network) 的提出改变了这一局面。它是由 Baoguang Shi 等人在 2015 年提出的,其核心思想是将卷积神经网络(CNN)提取的特征序列化,然后利用循环神经网络(RNN)处理序列信息,最后结合 CTC(Connectionist Temporal Classification)损失函数实现不定长序列的端到端训练和识别。

CRNN 的主要优点:

  • 端到端训练:不需要对字符进行单独切分和标注。
  • 处理不定长序列:无论是 3 个字符还是 10 个字符的文本,模型都能处理。
  • 结合上下文信息:RNN 能够捕捉字符之间的序列依赖关系,提高识别准确率(例如,识别出 "ca" 后,更有可能识别出 "t" 而不是 "f")。
  • 模型轻量高效:相比于基于 Attention 的模型,CRNN 训练和推理速度更快。

2. 网络概述

CRNN 的架构非常清晰,它结合了三种不同的神经网络技术,自底向上分为三层:

2.1 卷积层 (Convolutional Layers) - 特征提取

CRNN 的底部是一个标准的 CNN(通常使用标准 VGG 或其变体)。

  • 输入:灰度或 RGB 图像。
  • 输出:一个特征图(Feature Map)。
  • 作用:CNN 负责从输入图像中提取高维视觉特征。

2.2 循环层 (Recurrent Layers) - 序列建模

这是 CRNN 的核心创新点。模型将 CNN 输出的特征图转化为一个特征向量序列

  • 操作:特征图的每一列(或者特定宽度的区域)被视为序列中的一个“时间步(Time Step)”。
  • 网络:通常使用双向 LSTM(Bidirectional LSTM, BiLSTM)。BiLSTM 能够同时捕获序列的前向和后向上下文信息。
  • 作用:RNN 接收视觉特征序列,并输出对每个时间步字符分类的预测概率分布。

2.3 转录层 (Transcription Layer) - 序列解码与训练

由于 RNN 的输出序列长度与真实的文本标签长度往往不一致(RNN 输出的时间步长通常远大于文本长度),因此需要一种机制来弥合这个差距。

  • 技术CTC (Connectionist Temporal Classification)
  • CTC 的核心:引入了一个特殊的标记 "blank" (空白符)
  • 解码(推理):CTC 将 RNN 输出的多余字符和空白符压缩,得到最终的文本标签。例如,RNN 输出 aa-b--c-cc- 表示空白符),CTC 解码后变为 abc
  • 训练(损失):CTC 损失函数可以直接计算 RNN 输出概率序列与真实标签(如 "abc")之间的差异,从而实现端到端训练,而不需要知道每个字符在图像中的具体位置。

3. 详细网络结构

本节将详细展示 CRNN 的结构参数(基于经典实现,通常基于 VGG-11 或 VGG-16 裁剪)。

输入图像假设:为了方便序列化,通常将输入图像的高固定(例如 H=32H=32),宽可以是不定长(例如 W=100W=100)。

3.1 CNN 特征提取层网络配置

经典的 CRNN CNN 结构如下表所示:

Layer 类型配置参数 (Kernel, Stride, Padding)输出特征图尺寸 (N, C, H, W)备注
Input-(1, 1, 32, 100)假设输入灰度图,32×10032 \times 100
Conv1k:3, s:1, p:1(1, 64, 32, 100)
ReLU + MaxPool1k:2, s:2, p:0(1, 64, 16, 50)高、宽减半
Conv2k:3, s:1, p:1(1, 128, 16, 50)
ReLU + MaxPool2k:2, s:2, p:0(1, 128, 8, 25)高、宽再次减半
Conv3k:3, s:1, p:1(1, 256, 8, 25)
Batch Normalization-(1, 256, 8, 25)引入 BN 加速收敛
Conv4k:3, s:1, p:1(1, 256, 8, 25)
ReLU + MaxPool3k:(2,2), s:(2,1), p:(0,1)(1, 256, 4, 26)关键点:高减半,宽不减(微增),保留水平分辨率
Conv5k:3, s:1, p:1(1, 512, 4, 26)
Batch Normalization-(1, 512, 4, 26)
ReLU + Conv6k:(2,2), s:(2,1), p:(0,0)(1, 512, 1, 25)关键点:高变为 1。输出宽为 Wout=25W_{out}=25
Conv7k:2, s:1, p:0(1, 512, 1, 24)经典 VGG 后接一个 2×22\times2 无 padding 卷积进一步融合
Feature Map-(1, 512, 1, 24)最终 CNN 特征图

总结:CNN 的最终输出是 (Batch, Channels, 1, Width_seq)。这个 Width_seq 就是序列的长。

3.2 特征序列化 (Map-to-Sequence)

这是将视觉特征转化为文本序列预测的关键步骤。

  1. Squeeze:去除高度为 1 的维度:(Batch, Channels, 1, Width_seq) \rightarrow (Batch, Channels, Width_seq)
  2. Permute:调整维度顺序以符合 RNN 的输入要求:(Batch, Channels, Width_seq) \rightarrow (Width_seq, Batch, Channels)
    • 在 PyTorch 中,RNN 的默认输入格式是 (Time_steps, Batch, Input_size)
    • 此时,每个时间步的输入特征大小(Input_size)等于 CNN 输出的通道数(例如 512)。

3.3 RNN 序列预测层配置

通常使用两层双向 LSTM (BiLSTM)。

  • RNN Input: (24, 1, 512) (假设宽序列为 24)
  • RNN Output: (24, 1, Hidden_size * 2) (因为是双向)
  • Linear Layer: 将 RNN 的输出映射到类别数。
    • Output: (24, 1, Number_of_Classes)。这里的 Number_of_Classes 必须包含真实的字符集(如 26 个字母+数字)加上一个 "blank" 标记。

4. PyTorch 代码实现

下面是完整的 CRNN 网络结构的 PyTorch 代码。它包含 CNN 模块、RNN 模块,以及最终将它们组合在一起的 CRNN 模块。

import torch
import torch.nn as nn

class VGG_FeatureExtractor(nn.Module):
    """
    经典的 CRNN 后端 CNN 特征提取器 (裁剪版 VGG)
    输入高度必须固定为 32。宽可以是可变的。
    """
    def __init__(self, input_channel=1, output_channel=512):
        super(VGG_FeatureExtractor, self).__init__()
        self.output_channel = output_channel
        
        # 定义核心卷积网络
        self.cnn = nn.Sequential(
            # Conv Block 1: 32xW -> 16xW/2
            nn.Conv2d(input_channel, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2), # Output: 64 x 16 x W/2

            # Conv Block 2: 16xW/2 -> 8xW/4
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2), # Output: 128 x 8 x W/4

            # Conv Block 3: 8xW/4 -> 4xW/4
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), # BN 在特征提取中很有效
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # 关键:MaxPool 在高减半,但在宽方向stride为1,不减半
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)), # Output: 256 x 4 x (W/4+2)

            # Conv Block 4: 4xW/4 -> 1xW/4
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            # 关键:MaxPool 高减半为1。
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 0)), # Output: 512 x 1 x Width_seq

            # Conv Block 5: 进一步融合特征
            nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0), # 高、宽都为2的卷积,进一步降维
            nn.ReLU(True)
            # 最终输出通道数:512
        )

    def forward(self, x):
        # 输入 x: (Batch, Input_channel, 32, W)
        conv = self.cnn(x)
        return conv

class BidirectionalLSTM(nn.Module):
    """
    单层双向 LSTM
    """
    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        # PyTorch 的 LSTM: 输入 (Time, Batch, Input_size)
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
        # 双向输出是隐层大小的两倍,需要通过线性层降维到 output_size
        self.embedding = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        # 输入 x: (Time, Batch, Input_size)
        recurrent, _ = self.rnn(x) # recurrent: (Time, Batch, Hidden*2)
        
        # 融合时间维和 Batch 维进行 Linear 处理
        T, B, H2 = recurrent.size()
        t_rec = recurrent.view(T * B, H2)
        
        output = self.embedding(t_rec) # (Time*Batch, Output_size)
        output = output.view(T, B, -1) # 复原 (Time, Batch, Output_size)
        
        return output

class CRNN(nn.Module):
    def __init__(self, img_h, nc, nclass, nh):
        """
        img_h: 输入图像高度 (应为32)
        nc: 输入图像通道数 (1 for gray, 3 for rgb)
        nclass: 字符类别数 (必须包含 blank 标记,通常 blank 的 index 是 0)
        nh: RNN 的隐层神经元数量
        """
        super(CRNN, self).__init__()
        assert img_h % 16 == 0, 'img_h has to be a multiple of 16'
        
        # 1. CNN 特征提取层
        self.cnn = VGG_FeatureExtractor(nc, 512)
        
        # 2. RNN 序列建模层 (两层 BiLSTM)
        # 第一层: 将 CNN 通道数 512 映射到 RNN 隐层nh
        self.rnn1 = BidirectionalLSTM(512, nh, nh)
        # 第二层: 将第一层输出(其实还是nh)进一步建模,并映射到最终的字符类别数 nclass
        self.rnn2 = BidirectionalLSTM(nh, nh, nclass)

    def forward(self, x):
        # 1. CNN 层特征提取
        # Input x: (Batch, nc, img_h, W)
        conv = self.cnn(x)
        
        # 2. 特征序列化 (Map-to-Sequence)
        # conv shape: (Batch, 512, 1, Width_seq)
        b, c, h, w = conv.size()
        assert h == 1, "The height of conv feature map must be 1"
        
        # Remove height dim: (B, C, 1, W_seq) -> (B, C, W_seq)
        conv = conv.squeeze(2)
        # Permute for RNN: (B, C, W_seq) -> (W_seq, B, C)
        conv = conv.permute(2, 0, 1) # (Time_steps, Batch, Input_size)
        
        # 3. RNN 层序列预测
        rnn_out = self.rnn1(conv)
        rnn_out = self.rnn2(rnn_out)
        
        # 最终输出形状: (Time_steps, Batch, nclass)
        # 这也是 PyTorch nn.CTCLoss 需要的标准输入格式
        return rnn_out

# --- 测试模型输入输出 ---
if __name__ == "__main__":
    # 参数设置
    batch_size = 1
    input_channels = 1 # 灰度图
    img_h = 32
    img_w = 100        # 可变宽度
    nh = 256           # RNN 隐层神经元数
    # nclass 包括:空白符 + 字符集 (如 'a'-'z', '0'-'9')
    # 假设 blank 是 0, 字母是 1-26, 共 27 类
    nclass = 27 
    
    # 实例化模型
    model = CRNN(img_h, input_channels, nclass, nh)
    
    # 检查网络结构
    # print(model)

    # 模拟输入:1张灰度图, 32x100
    dummy_input = torch.randn(batch_size, input_channels, img_h, img_w)
    
    # 模型前向传播
    output = model(dummy_input)
    
    print(f"\n模型测试结果:")
    print(f"输入形状 (Batch, C, H, W): {dummy_input.shape}")
    print(f"输出形状 (Time_steps, Batch, nclass): {output.shape}")
    
    # 根据 32x100 的输入,CNN 最终输出宽序列大约是 24-26。
    # 输出的时间步长应该大约是这个数。
    time_steps, _, _ = output.shape
    assert time_steps > 0, "Error: Time_steps is 0"
    print("模型输出测试成功!")

5. 教程补充:CRNN 的训练与推理逻辑

理解代码仅仅是第一步,要让 CRNN 工作,你还需要理解它是如何训练和推理的。

5.1 训练逻辑 (Train) - 核心是 CTC Loss

在 PyTorch 中使用 CRNN 训练时,最重要的是正确设置 nn.CTCLoss

import torch.nn.functional as F

# 1. 实例化损失函数
# blank=0 表示我们在 nclass 中blank标记的索引是0
criterion = nn.CTCLoss(blank=0, reduction='mean')

# 2. 模型前向传播
# 假设 input 数据大小是 (Batch, 1, 32, W)
model_output = model(images) # Shape: (Time_steps, B, nclass)

# 3. 计算 Log-Probabilities
# CTC 需要输入是对数概率
log_probs = F.log_softmax(model_output, dim=2)

# 4. 准备 CTC Loss 需要的参数
# T 是序列的长度 (Time_steps),也就是 RNN 输出的第一维
input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long)
# target_lengths 是每张图中真实的标签长度 (例如 "hello" 是 5)
target_lengths = torch.tensor([len(t) for t in labels_encoded], dtype=torch.long)
# targets 是将真实文本转为数字的序列 (例如 [8, 5, 12, 12, 15]),并平铺成一维

# 5. 计算损失
loss = criterion(log_probs, targets, input_lengths, target_lengths)

# 6. 反向传播与更新
loss.backward()
optimizer.step()

5.2 推理逻辑 (Inference) - 解码过程

推理过程不需要计算 Loss,只需要将 RNN 输出的类别概率序列转化为最终的文本。

最简单的解码方法是贪婪解码 (Greedy Decoding)

  1. Get Predictions:对 RNN 输出的每个时间步,取概率最大的那个字符索引:preds = torch.argmax(output, dim=2)。这会得到一个像 [8, 8, 0, 0, 5, 0, 12, 12, 0, 12, 0, 15] 的数字序列。
  2. CTC Decode (压缩)
    • 步骤一:合并连续重复的非 blank 字符:[8, 0, 0, 5, 0, 12, 0, 12, 0, 15]
    • 步骤二:去除所有空白符(blank,在此为0):[8, 5, 12, 12, 15]
  3. 转文字:利用索引表将数字转回字符:hello

通过这篇教程,你应该对 CRNN 的前世今生、网络结构有了深刻理解,并且拥有了核心的 PyTorch 实现代码。CRNN 是 OCR 学习者的必经之路,希望本教程对你有所帮助!