引言
Vision Transformer(ViT)是计算机视觉领域的一颗重磅炸弹。它大胆地将自然语言处理中大获成功的 Transformer 架构直接用于图像分类,并且取得了令人瞩目的效果。ViT 的核心思路简单而巧妙:把图像切成一个个小块(patch),像对待句子里的单词一样,把这些图像块“喂”给 Transformer,让模型自行学习哪些部分值得关注。
本教程将用通俗易懂的语言,一步步拆解 ViT 的核心设计:图像怎么切、Patch Embedding 怎么做、位置编码怎么加、多头注意力怎么算……除了理论讲解,还会给出基于 PyTorch 的实现代码,并演示如何使用预训练模型。不管你是刚刚接触视觉 Transformer,还是想查缺补漏,希望这篇文章都能帮助你轻松上手。
📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:关键点检测 (Keypoints) · Swin Transformer
1.1 把图片当作“语言”来理解
传统卷积神经网络(CNN)通过不断堆叠卷积层,慢慢扩大感受野,从局部纹理逐步聚合出全局语义。而 Vision Transformer 一上来就看了全图,每个图像块都可以直接和其他所有块进行交互。这种“全局视野”正是其强大之处。
ViT 的主要创新点可以概括为三句话:
- 图像分块:把整张图切成固定大小的小方块(patch),每个方块展开后就像 NLP 中的一个“token”(单词)。
- 序列建模:把这些 patch 当作一个序列,送进 Transformer 编码器,利用自注意力机制寻找它们之间的关系。
- 全局连接:从第一层开始,每个 patch 都能看见所有其他 patch,天然适合捕捉长距离依赖。
1.2 ViT 的发展时间线
- 2017 年:Transformer 架构在论文 “Attention Is All You Need” 中被提出。
- 2018 年:BERT 等模型让 Transformer 在 NLP 领域大放异彩。
- 2020 年:ViT 发布,首次证明纯 Transformer 也能在图像分类上超越 CNN。
- 2021 年至今:Swin Transformer、PVT 等层次化 ViT 出现,并逐步在检测、分割任务中普及。
2. ViT 架构详解
ViT 的工作流程可以归纳为四个步骤:
- 把图像切成 patch,并映射成固定长度的向量(Patch Embedding)。
- 在序列最前面添加一个专门用于分类的 class token。
- 加入位置编码,让模型知道每个 patch 原来在图像中的位置。
- 送入 Transformer 编码器,最终用 class token 的输出进行分类。
下面我们详细拆解每一个步骤,并结合代码加深理解。
2.1 图像分块与 Patch Embedding
这是 ViT 的第一步,也是将图像从“网格结构”转换成“序列”的关键。
假设输入图像的大小为 224×224,我们设定 patch 大小为 16×16。这样,整张图在水平和垂直方向上各被切成 14 块,总共 14×14 = 196 个 patch。每个 patch 包含 16×16×3 = 768 个像素值,这些像素值会通过一个线性投影层映射到一个新的向量空间(维度保持不变或改变均可,ViT 原论文中保持 768 维)。
Patch Embedding 过程示意:
输入图像:(B, 3, 224, 224)
分割并展开:(B, 196, 768)
每个 patch 向量长度:768
下面是使用 PyTorch 和 einops 实现图像切块与投影的代码:
import torch
import torch.nn as nn
from einops import rearrange
class ImageToPatches(nn.Module):
"""
图像到 patch 的转换模块
"""
def __init__(self, image_size=224, patch_size=16, channels=3):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.patch_dim = channels * patch_size ** 2
# 线性投影层:将每个 patch 的原始像素映射到目标维度
self.projection = nn.Linear(self.patch_dim, self.patch_dim)
def forward(self, x):
"""
x: (batch, channels, height, width)
return: (batch, num_patches, patch_dim)
"""
batch_size, channels, height, width = x.shape
# 简单校验输入尺寸
assert height == self.image_size and width == self.image_size, \
f"输入图像尺寸应为 ({self.image_size}, {self.image_size})"
# 使用 einops 的 rearrange 优雅切分
x = rearrange(
x,
'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
p1=self.patch_size,
p2=self.patch_size
)
# 线性投影
x = self.projection(x)
return x
代码中,rearrange 将原本 (B, C, H, W) 的张量重新排列成 (B, num_patches, patch_dim),一行代码完成切分与扁平化,干净利落。
2.2 ViT 完整实现
掌握了 Patch Embedding 后,我们就可以搭建完整的 Vision Transformer 模型了。下面是一个清晰可读的 PyTorch 实现:
import torch
import torch.nn as nn
class VisionTransformer(nn.Module):
"""
Vision Transformer 完整实现
"""
def __init__(
self,
image_size=224,
patch_size=16,
num_classes=1000,
dim=768, # 嵌入维度
depth=12, # Transformer 层数
heads=12, # 注意力头数
mlp_dim=3072, # 前馈网络扩展维度
dropout=0.1,
emb_dropout=0.1
):
super(VisionTransformer, self).__init__()
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
# 用卷积一步完成切块+投影 (卷积核大小和步长都等于 patch_size)
self.to_patch_embedding = nn.Sequential(
nn.Conv2d(3, patch_dim, kernel_size=patch_size, stride=patch_size),
nn.Flatten(start_dim=2), # (B, patch_dim, H, W) -> (B, patch_dim, num_patches)
nn.Linear(patch_dim, dim)
)
# 可学习的分类 token 和位置编码
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.dropout = nn.Dropout(emb_dropout)
# Transformer 编码器(使用 PyTorch 官方的 TransformerEncoderLayer)
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim,
nhead=heads,
dim_feedforward=mlp_dim,
dropout=dropout,
activation='gelu',
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
# 最后的分类头
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
# 1. 获取 patch 嵌入
x = self.to_patch_embedding(img) # (B, num_patches, dim)
b, n, _ = x.shape
# 2. 在最前面拼接 class token
cls_tokens = self.cls_token.repeat(b, 1, 1) # (B, 1, dim)
x = torch.cat([cls_tokens, x], dim=1) # (B, num_patches+1, dim)
# 3. 加上位置编码
x = x + self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
# 4. 送入 Transformer 编码器
x = self.transformer(x)
# 5. 取出 class token 对应位置的输出进行分类
cls_output = x[:, 0] # (B, dim)
output = self.mlp_head(cls_output)
return output
这段代码基本复刻了 ViT 的结构,建议初学者逐行阅读,理解每一步张量形状的变化。
2.3 多头自注意力机制详解
自注意力是 Transformer 的核心,也是 ViT “看全图”能力的关键。我在这里给出一个更贴近原公式的实现,帮助大家理解内部计算过程:
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
"""
多头自注意力机制
"""
def __init__(self, d_model=768, num_heads=12, dropout=0.1):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 生成 Q, K, V 的线性变换层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
# 缩放因子
self.scale = torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
def forward(self, x):
batch_size, seq_len, _ = x.shape
# 线性投影
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# 拆分成多头: (B, seq_len, num_heads, d_k) -> (B, num_heads, seq_len, d_k)
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
attention = F.softmax(scores, dim=-1)
attention = self.dropout(attention)
# 加权求和
output = torch.matmul(attention, V) # (B, num_heads, seq_len, d_k)
# 合并多头
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
# 最终线性变换
output = self.W_o(output)
return output
通过这段代码可以直观地看到:注意力机制本质上就是让序列中的每个位置,根据自身与所有位置的相似度,去加权聚合所有位置的信息。
3. 位置编码与 Class Token
3.1 位置编码:让模型知道“你在哪儿”
Transformer 本身不具备感知输入顺序的能力,因此必须为每一个 patch 注入位置信息。ViT 中使用的是可学习的位置编码,即直接初始化一组参数,让模型在训练过程中自行调整。
常用的位置编码类型有:
- 可学习式(Learnable):ViT 的默认做法,简单直接。
- 正余弦式(Sinusoidal):无需额外参数,但 ViT 中效果稍逊。
- 二维位置编码(2D PE):保留 patch 的横纵坐标信息,更适合图像。
- 旋转位置编码(Rotary PE):在大模型和长序列中表现出色。
在 ViT 中,位置编码向量会直接加到 patch embedding 上,形状为 (num_patches + 1, dim)。之所以需要 +1,是因为还要为 class token 保留一个位置。
3.2 Class Token:一个“总揽全局”的特殊角色
ViT 在序列最前面放置了一个可学习的 class token。这个 token 不来自任何图像 patch,但经过多层 Transformer 后,它会逐步汇总所有 patch 的信息,成为整张图片的“代言人”。最终我们只需要提取 class token 的输出向量,送入分类头即可完成分类。
这样设计的好处是:
- 集中聚合图像级别的语义信息。
- 避免了额外设计全局池化层的需要。
- 与 BERT 中的
[CLS] token 一脉相承,容易理解和迁移。
4. ViT 的变体与改进
4.1 DeiT:数据高效版的 ViT
DeiT 主要解决了 ViT 依赖海量数据预训练的问题。它通过知识蒸馏,引入一个教师模型(通常是卷积网络)来指导 ViT 学习,使得即使只用百万级的数据集也能训练出不错的模型。此外,它使用了更强的数据增强和正则化手段。
4.2 更高效的 ViT 变体
纯 ViT 的计算量与 patch 数的平方相关,这让它在高分辨率图像上成本较高。后续研究提出了许多轻量级或高效版本,例如:
- MobileViT:融合了卷积的局部优势,适合移动设备。
- PVT(金字塔视觉 Transformer):采用逐渐缩小的特征图,形成类似 CNN 的金字塔结构。
- Swin Transformer:在局部窗口内计算自注意力,通过移动窗口实现跨窗口交互,大幅降低计算量。
- Twins:结合空间注意力和序列自注意力,兼顾局部与全局。
这些变体让 ViT 不仅在分类上表现优异,也能高效地用于检测、分割等下游任务。
5. 使用预训练模型
作为开发者,我们大多数情况下不需要从头训练 ViT。PyTorch 和 Hugging Face 都提供了开箱即用的预训练模型。
5.1 使用 PyTorch 官方模型
import torch
from torchvision import models, transforms
# 加载预训练 ViT-B/16
model = models.vit_b_16(weights='IMAGENET1K_V1')
model.eval()
# 图像预处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# 假设已经得到 input_tensor,Shape = (1, 3, 224, 224)
with torch.no_grad():
output = model(input_tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
# 加载处理器和模型
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# 读取图片
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
# 推理
outputs = model(**inputs)
predicted_class_idx = outputs.logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class_idx])
Hugging Face 的接口十分简洁,上手极快,非常适合快速验证想法或构建原型。
6. ViT 与 CNN 的对比
ViT 和 CNN 并不是非此即彼的关系,它们各有擅长,下面从几个关键维度进行对比:
选型建议:
- 小数据集、实时应用、移动端部署,依然是 CNN 的优势领域。
- 数据充足、追求更高准确率、需要多模态融合的场景,ViT 会更亮眼。
7. 实践技巧与调优
训练或微调一个 ViT,通常可以借助以下技巧:
- 优化器:优先使用 AdamW,搭配 weight decay。
- 学习率调度:warmup + 余弦退火,让训练更稳定。
- 数据增强:RandAugment、Mixup、CutMix 等可以有效提升泛化能力。
- 正则化:Dropout、Stochastic Depth、label smoothing。
- 知识蒸馏:用大模型或者 CNN 教师模型指导小 ViT,效果显著。
如果要部署到生产环境:
- 模型量化:INT8 量化可大幅减小体积,加速推理。
- 混合精度训练:节省显存,提升速度。
- 稀疏注意力:通过限制注意力范围降低计算量。
- 蒸馏小模型:在移动端也能获得接近大模型的性能。
相关教程
ViT 是计算机视觉的一个重要里程碑。建议先花时间理解 Transformer 的基本原理(尤其是自注意力),再回过头来看 ViT 的实现会轻松很多。同时,动手跑一跑代码、尝试加载预训练模型进行预测,也是快速建立直觉的好方法。
8. 总结
Vision Transformer 向世界证明了:图像分类不需要卷积也能做得很好,甚至更好。它的核心创新点可以归纳为三个环节:
- 图像切块(Patch):将图像变成序列,打通了视觉和语言的建模壁垒。
- 全局自注意力:每一个 patch 都能直接建模全局依赖,捕捉长距离特征。
- 可扩展性:模型可以像堆乐高一样增加深度和宽度,借助大量数据进一步提升性能。
无论你从事计算机视觉研究,还是想将前沿技术落地到产品,Vision Transformer 都是一项值得认真理解的技术。
💡 重要提醒:ViT 的出现,开启了视觉与语言模型统一建模的新时代,也催生了 CLIP、DALL·E 等一系列现象级多模态模型。
🔗 扩展阅读