Vision Transformer:从图像切片到Patch Embedding详解

引言

Vision Transformer(ViT)是计算机视觉领域的一颗重磅炸弹。它大胆地将自然语言处理中大获成功的 Transformer 架构直接用于图像分类,并且取得了令人瞩目的效果。ViT 的核心思路简单而巧妙:把图像切成一个个小块(patch),像对待句子里的单词一样,把这些图像块“喂”给 Transformer,让模型自行学习哪些部分值得关注。

本教程将用通俗易懂的语言,一步步拆解 ViT 的核心设计:图像怎么切、Patch Embedding 怎么做、位置编码怎么加、多头注意力怎么算……除了理论讲解,还会给出基于 PyTorch 的实现代码,并演示如何使用预训练模型。不管你是刚刚接触视觉 Transformer,还是想查缺补漏,希望这篇文章都能帮助你轻松上手。

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:关键点检测 (Keypoints) · Swin Transformer


1. Vision Transformer 的核心思想

1.1 把图片当作“语言”来理解

传统卷积神经网络(CNN)通过不断堆叠卷积层,慢慢扩大感受野,从局部纹理逐步聚合出全局语义。而 Vision Transformer 一上来就看了全图,每个图像块都可以直接和其他所有块进行交互。这种“全局视野”正是其强大之处。

ViT 的主要创新点可以概括为三句话:

  1. 图像分块:把整张图切成固定大小的小方块(patch),每个方块展开后就像 NLP 中的一个“token”(单词)。
  2. 序列建模:把这些 patch 当作一个序列,送进 Transformer 编码器,利用自注意力机制寻找它们之间的关系。
  3. 全局连接:从第一层开始,每个 patch 都能看见所有其他 patch,天然适合捕捉长距离依赖。

1.2 ViT 的发展时间线

  • 2017 年:Transformer 架构在论文 “Attention Is All You Need” 中被提出。
  • 2018 年:BERT 等模型让 Transformer 在 NLP 领域大放异彩。
  • 2020 年:ViT 发布,首次证明纯 Transformer 也能在图像分类上超越 CNN。
  • 2021 年至今:Swin Transformer、PVT 等层次化 ViT 出现,并逐步在检测、分割任务中普及。

2. ViT 架构详解

ViT 的工作流程可以归纳为四个步骤:

  1. 把图像切成 patch,并映射成固定长度的向量(Patch Embedding)。
  2. 在序列最前面添加一个专门用于分类的 class token。
  3. 加入位置编码,让模型知道每个 patch 原来在图像中的位置。
  4. 送入 Transformer 编码器,最终用 class token 的输出进行分类。

下面我们详细拆解每一个步骤,并结合代码加深理解。

2.1 图像分块与 Patch Embedding

这是 ViT 的第一步,也是将图像从“网格结构”转换成“序列”的关键。

假设输入图像的大小为 224×224,我们设定 patch 大小为 16×16。这样,整张图在水平和垂直方向上各被切成 14 块,总共 14×14 = 196 个 patch。每个 patch 包含 16×16×3 = 768 个像素值,这些像素值会通过一个线性投影层映射到一个新的向量空间(维度保持不变或改变均可,ViT 原论文中保持 768 维)。

Patch Embedding 过程示意:
输入图像:(B, 3, 224, 224)
分割并展开:(B, 196, 768)
每个 patch 向量长度:768

下面是使用 PyTorch 和 einops 实现图像切块与投影的代码:

import torch
import torch.nn as nn
from einops import rearrange

class ImageToPatches(nn.Module):
    """
    图像到 patch 的转换模块
    """
    def __init__(self, image_size=224, patch_size=16, channels=3):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2

        # 线性投影层:将每个 patch 的原始像素映射到目标维度
        self.projection = nn.Linear(self.patch_dim, self.patch_dim)

    def forward(self, x):
        """
        x: (batch, channels, height, width)
        return: (batch, num_patches, patch_dim)
        """
        batch_size, channels, height, width = x.shape

        # 简单校验输入尺寸
        assert height == self.image_size and width == self.image_size, \
            f"输入图像尺寸应为 ({self.image_size}, {self.image_size})"

        # 使用 einops 的 rearrange 优雅切分
        x = rearrange(
            x,
            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
            p1=self.patch_size,
            p2=self.patch_size
        )

        # 线性投影
        x = self.projection(x)

        return x

代码中,rearrange 将原本 (B, C, H, W) 的张量重新排列成 (B, num_patches, patch_dim),一行代码完成切分与扁平化,干净利落。

2.2 ViT 完整实现

掌握了 Patch Embedding 后,我们就可以搭建完整的 Vision Transformer 模型了。下面是一个清晰可读的 PyTorch 实现:

import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    """
    Vision Transformer 完整实现
    """
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=768,             # 嵌入维度
        depth=12,            # Transformer 层数
        heads=12,            # 注意力头数
        mlp_dim=3072,        # 前馈网络扩展维度
        dropout=0.1,
        emb_dropout=0.1
    ):
        super(VisionTransformer, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        # 用卷积一步完成切块+投影 (卷积核大小和步长都等于 patch_size)
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(3, patch_dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(start_dim=2),            # (B, patch_dim, H, W) -> (B, patch_dim, num_patches)
            nn.Linear(patch_dim, dim)
        )

        # 可学习的分类 token 和位置编码
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        # Transformer 编码器(使用 PyTorch 官方的 TransformerEncoderLayer)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # 最后的分类头
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # 1. 获取 patch 嵌入
        x = self.to_patch_embedding(img)    # (B, num_patches, dim)
        b, n, _ = x.shape

        # 2. 在最前面拼接 class token
        cls_tokens = self.cls_token.repeat(b, 1, 1)   # (B, 1, dim)
        x = torch.cat([cls_tokens, x], dim=1)         # (B, num_patches+1, dim)

        # 3. 加上位置编码
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        # 4. 送入 Transformer 编码器
        x = self.transformer(x)

        # 5. 取出 class token 对应位置的输出进行分类
        cls_output = x[:, 0]    # (B, dim)
        output = self.mlp_head(cls_output)

        return output

这段代码基本复刻了 ViT 的结构,建议初学者逐行阅读,理解每一步张量形状的变化。

2.3 多头自注意力机制详解

自注意力是 Transformer 的核心,也是 ViT “看全图”能力的关键。我在这里给出一个更贴近原公式的实现,帮助大家理解内部计算过程:

import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """
    多头自注意力机制
    """
    def __init__(self, d_model=768, num_heads=12, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 生成 Q, K, V 的线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # 缩放因子
        self.scale = torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # 线性投影
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # 拆分成多头: (B, seq_len, num_heads, d_k) -> (B, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # 加权求和
        output = torch.matmul(attention, V)   # (B, num_heads, seq_len, d_k)

        # 合并多头
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # 最终线性变换
        output = self.W_o(output)

        return output

通过这段代码可以直观地看到:注意力机制本质上就是让序列中的每个位置,根据自身与所有位置的相似度,去加权聚合所有位置的信息。


3. 位置编码与 Class Token

3.1 位置编码:让模型知道“你在哪儿”

Transformer 本身不具备感知输入顺序的能力,因此必须为每一个 patch 注入位置信息。ViT 中使用的是可学习的位置编码,即直接初始化一组参数,让模型在训练过程中自行调整。

常用的位置编码类型有:

  • 可学习式(Learnable):ViT 的默认做法,简单直接。
  • 正余弦式(Sinusoidal):无需额外参数,但 ViT 中效果稍逊。
  • 二维位置编码(2D PE):保留 patch 的横纵坐标信息,更适合图像。
  • 旋转位置编码(Rotary PE):在大模型和长序列中表现出色。

在 ViT 中,位置编码向量会直接加到 patch embedding 上,形状为 (num_patches + 1, dim)。之所以需要 +1,是因为还要为 class token 保留一个位置。

3.2 Class Token:一个“总揽全局”的特殊角色

ViT 在序列最前面放置了一个可学习的 class token。这个 token 不来自任何图像 patch,但经过多层 Transformer 后,它会逐步汇总所有 patch 的信息,成为整张图片的“代言人”。最终我们只需要提取 class token 的输出向量,送入分类头即可完成分类。

这样设计的好处是:

  • 集中聚合图像级别的语义信息。
  • 避免了额外设计全局池化层的需要。
  • 与 BERT 中的 [CLS] token 一脉相承,容易理解和迁移。

4. ViT 的变体与改进

4.1 DeiT:数据高效版的 ViT

DeiT 主要解决了 ViT 依赖海量数据预训练的问题。它通过知识蒸馏,引入一个教师模型(通常是卷积网络)来指导 ViT 学习,使得即使只用百万级的数据集也能训练出不错的模型。此外,它使用了更强的数据增强和正则化手段。

4.2 更高效的 ViT 变体

纯 ViT 的计算量与 patch 数的平方相关,这让它在高分辨率图像上成本较高。后续研究提出了许多轻量级或高效版本,例如:

  • MobileViT:融合了卷积的局部优势,适合移动设备。
  • PVT(金字塔视觉 Transformer):采用逐渐缩小的特征图,形成类似 CNN 的金字塔结构。
  • Swin Transformer:在局部窗口内计算自注意力,通过移动窗口实现跨窗口交互,大幅降低计算量。
  • Twins:结合空间注意力和序列自注意力,兼顾局部与全局。

这些变体让 ViT 不仅在分类上表现优异,也能高效地用于检测、分割等下游任务。


5. 使用预训练模型

作为开发者,我们大多数情况下不需要从头训练 ViT。PyTorch 和 Hugging Face 都提供了开箱即用的预训练模型。

5.1 使用 PyTorch 官方模型

import torch
from torchvision import models, transforms

# 加载预训练 ViT-B/16
model = models.vit_b_16(weights='IMAGENET1K_V1')
model.eval()

# 图像预处理
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# 假设已经得到 input_tensor,Shape = (1, 3, 224, 224)
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

5.2 使用 Hugging Face Transformers

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

# 加载处理器和模型
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# 读取图片
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

# 推理
outputs = model(**inputs)
predicted_class_idx = outputs.logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class_idx])

Hugging Face 的接口十分简洁,上手极快,非常适合快速验证想法或构建原型。


6. ViT 与 CNN 的对比

ViT 和 CNN 并不是非此即彼的关系,它们各有擅长,下面从几个关键维度进行对比:

特性CNNViT
感受野从小范围逐渐扩大,由局部走向全局从一开始就是全局视野
参数效率通常参数量较小,计算效率高参数量较大,需要更多数据
数据需求小数据集上容易获得较好结果依赖大规模预训练或知识蒸馏
可解释性特征图不易直观理解注意力权重可直接可视化
计算开销与图像尺寸呈线性关系随 patch 数量增加,平方级增长
归纳偏置强(平移不变性、局部相关性)弱(更通用,更接近原始数据)

选型建议:

  • 小数据集、实时应用、移动端部署,依然是 CNN 的优势领域。
  • 数据充足、追求更高准确率、需要多模态融合的场景,ViT 会更亮眼。

7. 实践技巧与调优

训练或微调一个 ViT,通常可以借助以下技巧:

  1. 优化器:优先使用 AdamW,搭配 weight decay。
  2. 学习率调度:warmup + 余弦退火,让训练更稳定。
  3. 数据增强:RandAugment、Mixup、CutMix 等可以有效提升泛化能力。
  4. 正则化:Dropout、Stochastic Depth、label smoothing。
  5. 知识蒸馏:用大模型或者 CNN 教师模型指导小 ViT,效果显著。

如果要部署到生产环境:

  • 模型量化:INT8 量化可大幅减小体积,加速推理。
  • 混合精度训练:节省显存,提升速度。
  • 稀疏注意力:通过限制注意力范围降低计算量。
  • 蒸馏小模型:在移动端也能获得接近大模型的性能。

相关教程

ViT 是计算机视觉的一个重要里程碑。建议先花时间理解 Transformer 的基本原理(尤其是自注意力),再回过头来看 ViT 的实现会轻松很多。同时,动手跑一跑代码、尝试加载预训练模型进行预测,也是快速建立直觉的好方法。

8. 总结

Vision Transformer 向世界证明了:图像分类不需要卷积也能做得很好,甚至更好。它的核心创新点可以归纳为三个环节:

  • 图像切块(Patch):将图像变成序列,打通了视觉和语言的建模壁垒。
  • 全局自注意力:每一个 patch 都能直接建模全局依赖,捕捉长距离特征。
  • 可扩展性:模型可以像堆乐高一样增加深度和宽度,借助大量数据进一步提升性能。

无论你从事计算机视觉研究,还是想将前沿技术落地到产品,Vision Transformer 都是一项值得认真理解的技术。

💡 重要提醒:ViT 的出现,开启了视觉与语言模型统一建模的新时代,也催生了 CLIP、DALL·E 等一系列现象级多模态模型。

🔗 扩展阅读