Vision Transformer (ViT) 详解:图像切片、Patch Embedding

📂 所属阶段:第四阶段 — 视觉新范式(Transformer 篇)
🔗 相关章节:关键点检测 (Keypoints) · Swin Transformer


1. ViT 核心思想

ViT = Vision Transformer

创新:将图像当作文本处理

步骤:
1. 将图像分成 16×16 的 Patches
2. 线性投影得到 Patch Embeddings
3. 加上位置编码
4. 输入 Transformer

2. ViT 实现

import torch
import torch.nn as nn
from einops import rearrange

class ViT(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12):
        super().__init__()
        
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        
        self.patch_embedding = nn.Linear(patch_dim, dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4),
            num_layers=depth
        )
        
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    
    def forward(self, x):
        # x: (B, 3, 224, 224)
        
        # 分割成 Patches
        patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
        
        # Patch Embedding
        x = self.patch_embedding(patches)
        
        # 添加 CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        
        # 位置编码
        x = x + self.pos_embedding
        
        # Transformer
        x = self.transformer(x)
        
        # 取 CLS token
        x = x[:, 0]
        
        # 分类
        x = self.mlp_head(x)
        return x

3. 使用预训练 ViT

import torch
from torchvision.models import vit_b_16

# 加载预训练 ViT
model = vit_b_16(pretrained=True)

# 推理
x = torch.randn(1, 3, 224, 224)
output = model(x)
print(output.shape)  # (1, 1000)

4. ViT vs CNN

特性CNNViT
感受野逐层增长全局(第一层)
参数量
数据需求多(需要大规模预训练)
计算效率
可解释性高(注意力可视化)

5. 小结

ViT 的优势:

1. 全局感受野:从第一层就看到全局
2. 可扩展性:可以处理任意分辨率
3. 可解释性:注意力权重可视化
4. 迁移学习:预训练效果好

2026 年趋势:
- 混合模型:CNN + Transformer
- 高效 ViT:减少计算量
- 多模态:Vision + Language

💡 记住:ViT 证明了 Transformer 不仅适用于 NLP,也适用于 CV。这是深度学习的重大转变。


🔗 扩展阅读