实战:基于 PyTorch 的 CNN 验证码识别教程

本教程将带你实现一个能识别“数字+字母”组合的深度学习模型。不同于调用现成库,我们将从底层原理出发,训练一个专属的识别引擎。
项目地址:https://github.com/MgArcher/VerificationCodeRecognition

1. 项目核心流程图

  1. 数据生成:使用 Python 库批量生成带有干扰的验证码图片。
  2. 构造网络:搭建多层 CNN 提取视觉特征。
  3. 模型训练:通过损失函数不断优化神经元参数。
  4. 验证与部署:将训练好的 .pth 模型应用到爬虫项目中。

2. 环境准备 (2026 兼容版)

在开始前,请确保安装了核心依赖:

pip install torch torchvision
pip install pillow numpy
pip install captcha  # 用于生成验证码数据集

3. 核心板块解析

3.1 数据集准备 (Data Generation)

在该仓库中,数据是动态生成的。验证码识别的难点在于:一个标签(Label)对应多个字符

  • One-Hot 编码:如果验证码是 4 位,每位包含 26 个字母+10 个数字,那么模型输出的向量长度应为 4×364 \times 36

3.2 CNN 模型架构

这是仓库中最关键的代码部分(简化逻辑):

import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 第一层卷积:提取基础边缘特征
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        
        # ... 经过多层卷积缩放尺寸 ...

        # 全连接层:将二维特征图转化为一维分类概率
        self.fc = nn.Linear(32 * 15 * 50, 4 * 36) 

    def forward(self, x):
        out = self.layer1(x)
        out = out.view(out.size(0), -1) # 展平
        out = self.fc(out)
        return out

4. 动手实战步奏

第一步:克隆并初始化

git clone https://github.com/MgArcher/VerificationCodeRecognition.git
cd VerificationCodeRecognition

第二步:生成训练集

运行仓库中的数据生成脚本,生成数万张验证码图片存入 dataset/train。数据量越大,模型的抗干扰能力越强。

第三步:启动训练

执行训练脚本。注意观察 Loss (损失值)

  • 如果 Loss 持续下降并趋于平稳,说明模型正在“学习”字符特征。
  • 建议:在 Windows 环境下若无显卡,可修改 device = torch.device("cpu")

第四步:模型推理 (Inference)

训练完成后会得到 model.pth。在你的爬虫代码中如下调用:

model.load_state_dict(torch.load('model.pth'))
model.eval()
# 将爬虫抓到的图片转化为 Tensor 后传入
output = model(img_tensor)
# 解析输出向量得到最终字符串

5. 为什么选这个库做教程? (技术深度)

  1. 理解多标签分类:不同于常规的“一张图一个类别”,验证码识别是“一张图多个类别”。通过这个项目,你可以掌握如何处理联合概率分布。
  2. 端到端识别 (End-to-End):传统的 OCR 需要先做“字符切割”,而这个库展示了如何直接输入整张图,输出整串字符,这是目前工业界最主流的做法。
  3. 定制化能力:如果爬虫遇到的验证码背景非常特殊(如 3D 扭曲),你可以通过修改生成代码模拟这种背景,训练出比通用库更准的模型。

💼 开发者避坑指南

  • 灰度化:在输入模型前,务必将图片转为单通道灰度图,这能减少 2/3 的计算量且不丢失关键特征。
  • 字符集不匹配:如果你的目标网站有大小写区分,记得在 constants.py 中修改字符集长度,否则模型会识别不出大写字母。