#Vision Transformer (ViT) 详解:图像切片、Patch Embedding
📂 所属阶段:第四阶段 — 视觉新范式(Transformer 篇)
🔗 相关章节:关键点检测 (Keypoints) · Swin Transformer
#1. ViT 核心思想
ViT = Vision Transformer
创新:将图像当作文本处理
步骤:
1. 将图像分成 16×16 的 Patches
2. 线性投影得到 Patch Embeddings
3. 加上位置编码
4. 输入 Transformer#2. ViT 实现
import torch
import torch.nn as nn
from einops import rearrange
class ViT(nn.Module):
def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12):
super().__init__()
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
self.patch_embedding = nn.Linear(patch_dim, dim)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4),
num_layers=depth
)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, x):
# x: (B, 3, 224, 224)
# 分割成 Patches
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
# Patch Embedding
x = self.patch_embedding(patches)
# 添加 CLS token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# 位置编码
x = x + self.pos_embedding
# Transformer
x = self.transformer(x)
# 取 CLS token
x = x[:, 0]
# 分类
x = self.mlp_head(x)
return x#3. 使用预训练 ViT
import torch
from torchvision.models import vit_b_16
# 加载预训练 ViT
model = vit_b_16(pretrained=True)
# 推理
x = torch.randn(1, 3, 224, 224)
output = model(x)
print(output.shape) # (1, 1000)#4. ViT vs CNN
| 特性 | CNN | ViT |
|---|---|---|
| 感受野 | 逐层增长 | 全局(第一层) |
| 参数量 | 少 | 多 |
| 数据需求 | 少 | 多(需要大规模预训练) |
| 计算效率 | 高 | 低 |
| 可解释性 | 低 | 高(注意力可视化) |
#5. 小结
ViT 的优势:
1. 全局感受野:从第一层就看到全局
2. 可扩展性:可以处理任意分辨率
3. 可解释性:注意力权重可视化
4. 迁移学习:预训练效果好
2026 年趋势:
- 混合模型:CNN + Transformer
- 高效 ViT:减少计算量
- 多模态:Vision + Language💡 记住:ViT 证明了 Transformer 不仅适用于 NLP,也适用于 CV。这是深度学习的重大转变。
🔗 扩展阅读

