Vision Transformer (ViT) 详解:从理论到实践的完整指南

引言

如果你在2019年问计算机视觉(CV)研究者,纯Transformer能不能替代CNN做图像分类,大概率会得到否定答案——归纳偏置的缺失、长序列自注意力的计算复杂度,都是看起来绕不开的坎。但2020年Google Brain和DeepMind联手抛出的《An Image is Worth 16x16 Words》彻底打破了这个刻板印象:在14M+标注图像的大规模预训练下,ViT首次让纯注意力模型超越了SOTA CNN,也标志着CV领域正式进入「注意力时代」。


1. ViT的背景与动机

1.1 传统CNN的局限性

CNN在CV领域统治近10年,但天生带有3个核心瓶颈:

  • 硬编码归纳偏置:局部感受野、平移不变性虽在小/中数据集上稳定,但限制了对图像「全局长距离依赖」(如人脸的眼睛和嘴巴关系、自然图像的上下文)的建模。
  • 计算依赖堆叠深度:要覆盖全图224x224分辨率的信息,CNN需堆叠十几甚至几十层卷积,梯度传播效率低。
  • 分辨率扩展成本高:深层CNN的感受野增长慢于计算量(感受野≈层数×核步长),处理448x448图像时计算复杂度翻4倍以上。

1.2 Transformer的迁移优势

Transformer在NLP领域的成功,为CV提供了一套通用序列建模工具

  • 天然全局感受野:任意两个位置的token(相当于NLP的词)都能直接交互,无需堆叠层。
  • 可扩展且通用:模型容量可通过「层数、embed_dim、注意力头数」轻松调整,同一架构稍改就能适配分类、检测、分割等任务。
  • 并行计算友好:不像RNN/LSTM那样顺序处理,所有token的自注意力可同时计算。

2. ViT的核心架构

2.1 整体结构速览

ViT的设计理念非常直接:把图像「伪装」成NLP的句子序列,丢进标准Transformer编码器就行。关键组件只有6个:

graph LR
A[输入图像<br/>224x224x3] --> B[Patch Embedding<br/>拆成16x16的patch<br/>每个投影到768维]
B --> C[拼接Class Token<br/>用于全局分类]
C --> D[添加可学习位置编码<br/>保留空间位置]
D --> E[标准Transformer编码器<br/>12层×12头]
E --> F[取Class Token输出<br/>过MLP Head做分类]

2.2 完整PyTorch实现

我们用最简洁的PyTorch原生代码实现ViT-Base(与原论文一致的基础版本),每个组件都加了清晰注释:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ----------------------------
# 1. 核心基础模块
# ----------------------------
class PatchEmbedding(nn.Module):
    """
    将图像分割成固定大小的patch并做线性嵌入
    等价于先用16x16卷积步长16,再展平转置
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2  # 224/16=14 → 14²=196个patch
        
        # 用卷积层实现高效的patch嵌入(避免手动拆图循环)
        self.conv_proj = nn.Conv2d(
            in_channels, embed_dim, 
            kernel_size=patch_size, stride=patch_size
        )
        
    def forward(self, x):
        # x: (batch_size, 3, 224, 224)
        x = self.conv_proj(x)  # → (batch_size, 768, 14, 14)
        x = x.flatten(2)        # → (batch_size, 768, 196)
        x = x.transpose(1, 2)   # → (batch_size, 196, 768)
        return x

class MultiHeadSelfAttention(nn.Module):
    """
    标准多头自注意力(MHSA)
    """
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        assert self.head_dim * n_heads == embed_dim, "embed_dim必须能被n_heads整除"
        
        # 一次性生成Q、K、V的线性层(比分开写高效)
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        
        # 1. 生成Q、K、V,并拆成多头
        qkv = self.qkv_proj(x)  # → (batch_size, seq_len, 3*embed_dim)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # → (3, batch_size, n_heads, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 2. 计算注意力权重 + 缩放 + dropout
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)
        
        # 3. 加权求和V + 拼接多头 + 输出投影
        context = torch.matmul(attn_probs, v)  # → (batch_size, n_heads, seq_len, head_dim)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.out_proj(context)

class MLPBlock(nn.Module):
    """
    ViT中的MLP块:GELU激活 + 隐藏层扩张4倍
    """
    def __init__(self, embed_dim=768, mlp_dim=3072, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, mlp_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mlp_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerEncoderLayer(nn.Module):
    """
    标准Transformer编码器层:Pre-LN架构(原论文采用)
    Pre-LN:先做LayerNorm,再过残差,更稳定
    """
    def __init__(self, embed_dim=768, n_heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.mhsa = MultiHeadSelfAttention(embed_dim, n_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = MLPBlock(embed_dim, mlp_dim, dropout)
        
    def forward(self, x):
        # Pre-LN + 残差
        x = x + self.mhsa(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

# ----------------------------
# 2. ViT完整模型
# ----------------------------
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, n_heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()
        
        # 1. Patch Embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        
        # 2. Class Token:可学习的全局分类向量
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # 3. 可学习位置编码:无硬编码正弦,让模型自己学
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)
        
        # 4. 堆叠Transformer编码器
        self.encoder = nn.Sequential(*[
            TransformerEncoderLayer(embed_dim, n_heads, mlp_dim, dropout)
            for _ in range(depth)
        ])
        
        # 5. 分类头
        self.ln_head = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # 6. 初始化权重(原论文用截断正态分布)
        self._init_weights()
        
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        
    def forward(self, x):
        batch_size = x.size(0)
        
        # 步骤1:Patch嵌入
        x = self.patch_embed(x)  # → (batch_size, 196, 768)
        
        # 步骤2:拼接Class Token(每个样本补一个)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # → (batch_size, 1, 768)
        x = torch.cat([cls_tokens, x], dim=1)  # → (batch_size, 197, 768)
        
        # 步骤3:加位置编码 + dropout
        x = x + self.pos_embed
        x = self.pos_dropout(x)
        
        # 步骤4:过编码器
        x = self.encoder(x)
        
        # 步骤5:取Class Token输出分类
        x = self.ln_head(x[:, 0])  # 仅取第一个位置(Class Token)的输出
        x = self.head(x)
        
        return x

# ----------------------------
# 3. 测试模型
# ----------------------------
if __name__ == "__main__":
    # 创建ViT-Base实例(对应原论文ViT-B/16)
    vit_base = VisionTransformer(
        img_size=224, patch_size=16,
        embed_dim=768, depth=12, n_heads=12, mlp_dim=3072,
        num_classes=1000
    )
    
    # 统计参数量
    total_params = sum(p.numel() for p in vit_base.parameters())
    print(f"ViT-B/16 总参数量: {total_params / 1e6:.1f}M")  # 约86M
    
    # 测试前向传播
    dummy_img = torch.randn(1, 3, 224, 224)
    output = vit_base(dummy_img)
    print(f"输入形状: {dummy_img.shape}")
    print(f"输出形状: {output.shape}")  # 应该是(1, 1000)

3. ViT vs CNN:直观对比

维度ViT-B/16ResNet-50
总参数量~86M~25M
全局感受野第1层就有需30层以上
归纳偏置几乎无强(局部性、平移不变)
小数据集表现(如CIFAR-10)易过拟合(需数据增强/蒸馏)稳定
大规模数据集表现(如ImageNet-21K预训练)超过ResNet-50的SOTA达到瓶颈
可解释性可通过注意力权重可视化「关注区域」只能看特征图/梯度CAM

4. 训练与使用的小建议

4.1 核心训练技巧

  1. 必须用大数据预训练:至少1M+图像,最好是ImageNet-21K(14M)或JFT-300M(原论文用的);
  2. 数据增强要狠:用RandAugment、MixUp、CutMix代替简单的随机裁剪翻转;
  3. 中小数据集怎么办? 用知识蒸馏(如DeiT用RegNetY-16GF作为教师模型);
  4. 学习率设置:预训练时用更大的学习率(如ViT-B用3e-3)+ 余弦退火+warmup;

4.2 快速上手预训练模型

不用自己从零训练!直接用torchvisiontimm库的预训练权重:

import torchvision.models as models

# 加载ImageNet-1K预训练的ViT-B/16
vit_b = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
print(vit_b.eval())

# 假设要改分类头做自己的任务(比如10类)
vit_b.heads = nn.Linear(768, 10)

总结

ViT不是「推翻CNN」,而是给CV领域提供了一套更通用、可扩展的建模范式。现在的主流视觉模型(如DETR、Mask2Former、CLIP、SAM)几乎都以ViT或其变体(Swin、DeiT)为基础。

如果想深入学习,建议先动手跑通上面的PyTorch代码,再试试用timm库的预训练模型做微调,最后阅读原论文和Swin Transformer的改进思路。


相关教程

可以用CIFAR-100数据集试试「DeiT式」微调:加载timm的`deit_base_patch16_224`预训练权重,替换分类头,用RandAugment+MixUp训练,效果会比从零训好很多!

🔗 扩展阅读