ViT(Vision Transformer)

1. 前言:为什么要抛弃卷积?

虽然 CNN 很强,但它有一个天然的局限性:归纳偏置(Inductive Bias)

  • 局部性:卷积核每次只看几个像素,很难直接捕捉图像左上角和右下角之间的联系(即全局信息)。
  • 静态权重:卷积核训练好后,对所有图片的过滤方式是一样的。

ViT 的核心思想:利用 自注意力机制(Self-Attention)。它让图像中的每一个区域都去跟其他所有区域“打招呼”,从而自动学习哪些部分是相关的。


2. 网络概述:ViT 的四大步

ViT 并不直接把像素丢进 Transformer,而是经过了精妙的转化:

  1. 图像分块 (Patch Embedding):将一张 224×224224 \times 224 的图片切成 16×1616 \times 16 个小方块(Patches)。每个 Patch 就像是一个“单词”。
  2. 线性投影:把每个 Patch 拉直并映射到一个固定维度的向量。
  3. 位置编码 (Position Embedding):因为 Transformer 无法感知顺序,必须给每个 Patch 加上一个“坐标”,告诉模型谁在谁旁边。
  4. CLS Token:专门在序列开头加一个额外的向量,用来汇总整张图的特征,最后用于分类。

3. 详细网络结构:PyTorch 实现

我们手动实现一个简易版的 ViT 核心流程。

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """将图像切块并映射为向量"""
    def __init__(self, img_size=224, patch_size=16, in_ch=3, embed_dim=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        # 使用卷积来实现切块和投影(巧妙的技巧:kernel=stride=patch_size)
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, 3, 224, 224) -> (B, 768, 14, 14)
        x = self.proj(x)
        x = x.flatten(2) # (B, 768, 196)
        x = x.transpose(1, 2) # (B, 196, 768)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, n_classes=10):
        super().__init__()
        self.patch_embed = PatchEmbedding()
        
        # 可学习的 CLS Token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        # 位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + 196, 768))
        
        # Transformer 编码器层
        encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=12)
        
        # 分类头
        self.mlp_head = nn.Linear(768, n_classes)

    def forward(self, x):
        # 1. 切块投影
        x = self.patch_embed(x)
        
        # 2. 拼接 CLS Token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        
        # 3. 叠加位置编码
        x = x + self.pos_embed
        
        # 4. 进入 Transformer 层
        x = self.transformer(x)
        
        # 5. 取出 CLS Token 的输出进行分类
        return self.mlp_head(x[:, 0])

# 测试
model = VisionTransformer()
img = torch.randn(1, 3, 224, 224)
logits = model(img)
print(f"输出维度: {logits.shape}") # (1, 10)

4. ViT vs CNN:该选谁?

在你的 daomanpy.com 教程中,可以为学员总结以下对比:

特性CNN (如 ResNet)ViT (Vision Transformer)
数据量需求中小规模数据表现良好极度依赖大数据 (如 ImageNet-21k)
全局感知需要通过深层堆叠获得第一层就能看到全局
训练难度容易收敛训练较慢,对超参数和优化器敏感
可解释性关注局部边缘/纹理关注图像各部分之间的逻辑关联

5. 总结与应用建议

ViT 的变体:

  • Swin Transformer:引入了层级结构和移动窗口(Shifted Windows),大大降低了计算量,目前在检测和分割领域非常火。
  • MAE (Masked Autoencoders):何恺明提出的自监督学习方法,让 ViT 学习如何补全被遮挡的图片,是目前的工业界主流。

你的 daomanpy.com 教程建议: 对于初学者,建议先用 CNN (YOLO/DBNet) 解决实际工程问题。当你的数据集达到十万级以上,且算力充足(比如有 3090/4090 显卡)时,再考虑引入 ViT 来刷高精度。

你已经跑通了这么多经典模型,下一节要不要我带你看看如何用 TensorRT 优化这些模型,让它们在你的 D 盘环境下跑出飞一般的速度?