第一课:猫狗图片分类 —— 从底层到前沿

猫狗分类是计算机视觉的“Hello World”。本课将带你通过两种截然不同的方式完成任务:

  1. 致敬经典:使用 PyTorch 手写一个卷积神经网络 (CNN)。
  2. 拥抱未来:利用 Vision Transformer (ViT) 预训练模型实现高性能识别。

1. 核心目标

  • 理解 CNN 的局部特征提取原理。
  • 掌握 Transformer 的图像分块 (Patchifying) 与全局建模。
  • 学会使用 timm 库快速调用 2026 年顶尖的深度学习模型。

2. 方案一:知其然 —— 手写轻量级 CNN

CNN 是视觉识别的基石。通过卷积层提取边缘,池化层降低维度,全连接层输出结果。

2.1 PyTorch 模型定义

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=2):
        super(SimpleCNN, self).__init__()
        # 卷积层:提取图像特征(边缘、纹理)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # 池化层:压缩尺寸,减少计算量
        self.pool = nn.MaxPool2d(2, 2)
        
        # 全连接层:假设输入 224x224,经过两次池化变为 56x56
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # 224 -> 112
        x = self.pool(F.relu(self.conv2(x))) # 112 -> 56
        x = x.view(x.size(0), -1)           # 展平
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn_model = SimpleCNN(num_classes=2).to(device)
print("手写 CNN 架构已就绪")

3. 方案二:知其所以然 —— 预训练 Vision Transformer

在 2026 年,我们很少从零训练模型。利用在超大规模数据集上训练过的 ViT,可以实现“降维打击”。

3.1 核心流程:图像变序列

ViT 不使用卷积,而是将图像切成 16x16 的小块 (Patches),像处理文字一样处理图像块之间的关联。

3.2 快速推理实现

import timm

# 加载 2026 主流的 ViT 微型版本,兼顾速度与精度
vit_model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=2)
vit_model = vit_model.to(device)
vit_model.eval()

# 使用 timm 自带的配置进行预处理
config = timm.data.resolve_model_data_config(vit_model)
transform = timm.data.create_transform(**config, is_training=False)

print("预训练 ViT 模型加载成功")

4. 两种方案深度对比

维度手写 CNN预训练 ViT (Transformer)
理解难度直观,适合学习底层算子抽象,需理解自注意力机制
数据需求需要海量标注数据从头训练少量数据进行微调 (Fine-tune) 即可
性能表现适合简单场景(如验证码)适合复杂场景(如自然场景识别)
部署建议极易导出 ONNX,兼容性极强推理开销略大,建议配合推理加速引擎

5. 开发者实战:编写通用预测函数

from PIL import Image

def predict(img_path, model, model_type="vit"):
    img = Image.open(img_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(img_tensor)
        pred = output.argmax(dim=1).item()
        
    return "狗" if pred == 1 else "猫"

# 示例调用
# print(f"CNN 结果: {predict('test.jpg', cnn_model)}")
# print(f"ViT 结果: {predict('test.jpg', vit_model)}")
`