Vision Transformer (ViT)详解:从图像到序列的视觉革命

引言

2020年,Google甩出一篇《An Image is Worth 16x16 Words》,直接动摇了CNN在计算机视觉领域的长期霸主地位。这篇文章提出的Vision Transformer (ViT),第一次让NLP界的大杀器——Transformer——在图像分类任务上稳稳站住了脚跟,甚至在大规模数据上反超了ResNet等经典卷积网络。

ViT的核心理念只有一句话,但足够震撼:

像处理自然语言一样处理图像——把图片切成一个个“视觉词”,再用自注意力机制一次性捕捉全局关系。

这种思路到底有多颠覆?我们往下看。


1. ViT的诞生:CNN的“天花板”与破局

1.1 为什么CNN不够用了?

ResNet、EfficientNet这些卷积神经网络虽然把局部特征提取做到了极致,但它们天生有两个绕不开的先验设定

归纳偏置带来的优势隐含的局限性
局部感受野学习边缘、纹理等基础特征快得堆叠几十层才能勉强“看”到全局(比如狗的身体轮廓)
平移不变性对物体位置变化不那么敏感缺乏对位置关系的显式建模(“猫耳朵在猫眼睛上面”)
静态权重推理效率高所有图像用同一套卷积核,无法动态关注“当前图像的关键区域”

简单来说,CNN就像一个只看细节的画家,画完所有局部之后,还需要花很多力气才能拼出完整的画面。而当时的研究人员开始思考:有没有一种方法,能让模型从一开始就看到全局

1.2 ViT的破局思路

ViT的做法非常干脆:它直接把CNN“局部优先”的设计推倒重来,换成了Transformer的“全局优先”范式。

ViT带来的几个关键升级:

  1. 全局感知一步到位:第一层自注意力就能让左上角的像素和右下角的像素直接“对话”
  2. 注意力权重动态生成:根据图像内容,自动调整不同区域的重要性,不再是死板的静态卷积核
  3. 可扩展性极强:模型越大、数据越多,性能提升越明显——Scaling Law在视觉领域也生效了

2. 极简架构拆解:ViT到底做了哪几件事?

ViT的整体结构几乎完全复用了NLP Transformer的编码器,唯一的变化就是把“文本序列”换成了“视觉序列”。整个流程可以概括为4步:

graph TD
    A[输入图像<br/>224×224×3] --> B[切割为不重叠的块<br/>16×16×3 × 196块]
    B --> C[线性投影<br/>每块→768维向量]
    C --> D[拼接CLS Token + 位置编码<br/>197×768]
    D --> E[堆叠N层Transformer编码器]
    E --> F[取CLS Token输出<br/>分类头→1000类]

关键组件速览

  1. Patch Embedding(图像转序列的核心):用卷积或展平+线性层,把图像块转成固定维度的向量,相当于把图像翻译成Transformer认识的“单词”
  2. CLS Token:在输入序列最前面拼接一个可学习的“全局汇总向量”,最终用它来做分类——这个巧思直接借鉴了BERT
  3. 可学习位置编码:把图像块的“位置信息”注入向量(因为Transformer本身对位置无感,必须告诉它哪块在哪)
  4. Transformer编码器:多头自注意力 + 全连接前馈网络 + 残差连接 + 层归一化,经典配方

3. PyTorch极简实现:从零搭ViT-B/16

下面我们用PyTorch实现ViT最经典的变体——ViT-B/16(Base规模,16×16的块大小)。代码力求清晰,关键步骤都加了注释。

3.1 第一步:把图像切成“视觉词汇”

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """
    图像分块嵌入:用卷积高效实现“分块+线性投影”
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2  # 14×14=196个块
        
        # 卷积核大小=stride=patch_size,一步到位分块+投影
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # 输入: (B, C, H, W) → 输出: (B, n_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

# 测试一下
if __name__ == "__main__":
    patch_embed = PatchEmbedding()
    dummy_img = torch.randn(2, 3, 224, 224)  # 2张RGB图
    print(f"Patch嵌入后形状: {patch_embed(dummy_img).shape}")  # 输出: torch.Size([2, 196, 768])

3.2 第二步:搭Transformer编码器块

class MultiHeadAttention(nn.Module):
    """
    多头自注意力:简化版实现
    """
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # 1. 计算QKV并拆分多头
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        
        # 2. 缩放点积注意力
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.dropout(attn.softmax(dim=-1))
        
        # 3. 拼接多头并投影
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class TransformerBlock(nn.Module):
    """
    Transformer编码器块:Pre-Norm结构
    """
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        # MLP:中间层维度是embed_dim的4倍
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # 残差连接1
        return x + self.mlp(self.norm2(x))  # 残差连接2

3.3 第三步:组装完整ViT模型

class VisionTransformer(nn.Module):
    """
    完整ViT-B/16模型
    """
    def __init__(self, img_size=224, n_classes=1000, depth=12):
        super().__init__()
        self.patch_embed = PatchEmbedding()
        
        # CLS Token + 可学习位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.zeros(1, 196 + 1, 768))  # +1是CLS Token
        self.pos_drop = nn.Dropout(0.1)
        
        # 堆叠12层Transformer块
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(depth)])
        
        # 分类头
        self.norm = nn.LayerNorm(768)
        self.head = nn.Linear(768, n_classes)

    def forward(self, x):
        B = x.shape[0]
        
        # 1. 图像转Patch
        x = self.patch_embed(x)
        
        # 2. 拼接CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # 3. 加位置编码
        x = self.pos_drop(x + self.pos_embed)
        
        # 4. 通过Transformer编码器
        for block in self.blocks:
            x = block(x)
        
        # 5. 取CLS Token输出分类
        return self.head(self.norm(x)[:, 0])

# 测试完整模型
if __name__ == "__main__":
    vit = VisionTransformer()
    dummy_img = torch.randn(4, 3, 224, 224)
    print(f"ViT输出形状: {vit(dummy_img).shape}")  # 输出: torch.Size([4, 1000])

4. 避坑指南:用好ViT的3个关键

ViT没有CNN那种“局部感受野”的先验知识,在**中小规模数据集(少于100万张)上直接训练,效果一定不如CNN**。 ✅ 正确做法:使用大规模预训练权重(比如ImageNet-21k上预训练,或基于MAE的自监督权重),然后在你的任务上微调。

4.1 超参数选择

超参数推荐配置说明
优化器AdamW (lr=1e-3, weight_decay=0.05)必须用带权重衰减的AdamW
学习率调度Warmup (10 epochs) + Cosine Decay训练初期用小学习率,避免CLS Token震荡
数据增强RandAugment + CutMix + MixUp高级增强是ViT在小数据集上收敛的关键
Batch Size越大越好(至少256,推荐1024+)大Batch能稳定注意力权重的训练

4.2 什么时候用ViT,什么时候用CNN?

def choose_between_vit_cnn(data_size, is_speed_critical):
    if data_size < 100_000:
        return "首选CNN(ResNet/EfficientNet),可考虑用DeiT蒸馏"
    elif data_size > 1_000_000 and not is_speed_critical:
        return "首选ViT(或Swin Transformer),用大规模预训练权重微调"
    else:
        return "折中方案:MobileViT(移动端)/CoAtNet(混合架构)"

5. 总结

Vision Transformer用“序列建模”的统一范式,为计算机视觉打开了一扇新的大门。虽然它确实存在“吃数据”、“计算量偏大”的短板,但在大规模预训练 + 下游微调的模式下,ViT已经成为图像分类、目标检测、语义分割等任务的主流选择之一。

如果你打算进一步探索ViT的演进家族,建议按这个顺序往下读:

  1. DeiT:解决ViT在小数据集上的训练难问题,用蒸馏方法提升性能
  2. Swin Transformer:引入层次化结构和滑动窗口,更适合检测、分割任务
  3. MAE:自监督预训练的代表作,大幅降低ViT对标注数据的需求

相关教程

1. 用`timm`库加载预训练的ViT-B/16,在CIFAR-10上微调,感受其效果 2. 尝试可视化ViT的注意力热力图,直观理解模型到底在“看”图片的哪些部分

🔗 扩展阅读