端到端不定长文字识别:CRNN 模型详解与 PyTorch 实现
1. 前言
在 OCR(光学字符识别)领域,识别图像中的文字序列一直是一个核心挑战。传统的 OCR 方法通常分为两个阶段:先检测单个字符,然后再进行分类。这种方法需要对每个字符进行精细的切分,不仅繁琐,而且对于粘连字符或模糊文本的识别效果很差。
CRNN (Convolutional Recurrent Neural Network) 的提出改变了这一局面。它是由 Baoguang Shi 等人在 2015 年提出的,其核心思想是将卷积神经网络(CNN)提取的特征序列化,然后利用循环神经网络(RNN)处理序列信息,最后结合 CTC(Connectionist Temporal Classification)损失函数实现不定长序列的端到端训练和识别。
CRNN 的主要优点:
- 端到端训练:不需要对字符进行单独切分和标注。
- 处理不定长序列:无论是 3 个字符还是 10 个字符的文本,模型都能处理。
- 结合上下文信息:RNN 能够捕捉字符之间的序列依赖关系,提高识别准确率(例如,识别出 "ca" 后,更有可能识别出 "t" 而不是 "f")。
- 模型轻量高效:相比于基于 Attention 的模型,CRNN 训练和推理速度更快。
2. 网络概述
CRNN 的架构非常清晰,它结合了三种不同的神经网络技术,自底向上分为三层:
2.1 卷积层 (Convolutional Layers) - 特征提取
CRNN 的底部是一个标准的 CNN(通常使用标准 VGG 或其变体)。
- 输入:灰度或 RGB 图像。
- 输出:一个特征图(Feature Map)。
- 作用:CNN 负责从输入图像中提取高维视觉特征。
2.2 循环层 (Recurrent Layers) - 序列建模
这是 CRNN 的核心创新点。模型将 CNN 输出的特征图转化为一个特征向量序列。
- 操作:特征图的每一列(或者特定宽度的区域)被视为序列中的一个“时间步(Time Step)”。
- 网络:通常使用双向 LSTM(Bidirectional LSTM, BiLSTM)。BiLSTM 能够同时捕获序列的前向和后向上下文信息。
- 作用:RNN 接收视觉特征序列,并输出对每个时间步字符分类的预测概率分布。
2.3 转录层 (Transcription Layer) - 序列解码与训练
由于 RNN 的输出序列长度与真实的文本标签长度往往不一致(RNN 输出的时间步长通常远大于文本长度),因此需要一种机制来弥合这个差距。
- 技术:CTC (Connectionist Temporal Classification)。
- CTC 的核心:引入了一个特殊的标记 "blank" (空白符)。
- 解码(推理):CTC 将 RNN 输出的多余字符和空白符压缩,得到最终的文本标签。例如,RNN 输出
aa-b--c-cc(-表示空白符),CTC 解码后变为abc。 - 训练(损失):CTC 损失函数可以直接计算 RNN 输出概率序列与真实标签(如 "abc")之间的差异,从而实现端到端训练,而不需要知道每个字符在图像中的具体位置。
3. 详细网络结构
本节将详细展示 CRNN 的结构参数(基于经典实现,通常基于 VGG-11 或 VGG-16 裁剪)。
输入图像假设:为了方便序列化,通常将输入图像的高固定(例如 ),宽可以是不定长(例如 )。
3.1 CNN 特征提取层网络配置
经典的 CRNN CNN 结构如下表所示:
总结:CNN 的最终输出是 (Batch, Channels, 1, Width_seq)。这个 Width_seq 就是序列的长。
3.2 特征序列化 (Map-to-Sequence)
这是将视觉特征转化为文本序列预测的关键步骤。
- Squeeze:去除高度为 1 的维度:
(Batch, Channels, 1, Width_seq)(Batch, Channels, Width_seq)。 - Permute:调整维度顺序以符合 RNN 的输入要求:
(Batch, Channels, Width_seq)(Width_seq, Batch, Channels)。- 在 PyTorch 中,RNN 的默认输入格式是
(Time_steps, Batch, Input_size)。 - 此时,每个时间步的输入特征大小(
Input_size)等于 CNN 输出的通道数(例如 512)。
- 在 PyTorch 中,RNN 的默认输入格式是
3.3 RNN 序列预测层配置
通常使用两层双向 LSTM (BiLSTM)。
- RNN Input:
(24, 1, 512)(假设宽序列为 24) - RNN Output:
(24, 1, Hidden_size * 2)(因为是双向) - Linear Layer: 将 RNN 的输出映射到类别数。
- Output:
(24, 1, Number_of_Classes)。这里的Number_of_Classes必须包含真实的字符集(如 26 个字母+数字)加上一个 "blank" 标记。
- Output:
4. PyTorch 代码实现
下面是完整的 CRNN 网络结构的 PyTorch 代码。它包含 CNN 模块、RNN 模块,以及最终将它们组合在一起的 CRNN 模块。
5. 教程补充:CRNN 的训练与推理逻辑
理解代码仅仅是第一步,要让 CRNN 工作,你还需要理解它是如何训练和推理的。
5.1 训练逻辑 (Train) - 核心是 CTC Loss
在 PyTorch 中使用 CRNN 训练时,最重要的是正确设置 nn.CTCLoss。
5.2 推理逻辑 (Inference) - 解码过程
推理过程不需要计算 Loss,只需要将 RNN 输出的类别概率序列转化为最终的文本。
最简单的解码方法是贪婪解码 (Greedy Decoding):
- Get Predictions:对 RNN 输出的每个时间步,取概率最大的那个字符索引:
preds = torch.argmax(output, dim=2)。这会得到一个像[8, 8, 0, 0, 5, 0, 12, 12, 0, 12, 0, 15]的数字序列。 - CTC Decode (压缩):
- 步骤一:合并连续重复的非 blank 字符:
[8, 0, 0, 5, 0, 12, 0, 12, 0, 15]。 - 步骤二:去除所有空白符(blank,在此为0):
[8, 5, 12, 12, 15]。
- 步骤一:合并连续重复的非 blank 字符:
- 转文字:利用索引表将数字转回字符:
hello。
通过这篇教程,你应该对 CRNN 的前世今生、网络结构有了深刻理解,并且拥有了核心的 PyTorch 实现代码。CRNN 是 OCR 学习者的必经之路,希望本教程对你有所帮助!

