MAE (Masked Autoencoders):自监督学习的视觉预训练方法详解

引言

Masked Autoencoders (MAE) 是何恺明等人在2021年提出的一种革命性自监督学习方法。它巧妙地将自然语言处理中BERT的掩码语言建模思想迁移到计算机视觉领域。MAE 的核心思路是:随机遮挡图像中高达 75% 的小块(patch),然后训练一个模型来重建这些被遮挡的部分。在这个过程中,模型被迫学习图像的深层结构信息,从而获得强大的视觉表征能力。这一方法大幅推动了自监督学习在视觉领域的发展,并为 Vision Transformer(ViT)的预训练提供了一种全新的高效范式。

📂 所属阶段:第二阶段 — 深度学习视觉基础(视觉Transformer 篇)
🔗 相关章节:Swin Transformer · Vision-Language 多模态


1. MAE 核心思想与动机

1.1 自监督学习的兴起

自监督学习是当前深度学习的一个重要方向,它的目标是利用海量无标注数据,让模型自己构造监督信号进行预训练。这一思路受到以下几个动因的强烈推动:

  • 数据效率:避免昂贵且耗时的人工标注,可以直接复用互联网或工业界中大量公开或私有的无标签图像。
  • 成本效益:无需专业标注团队,大幅降低 AI 模型的开发成本。
  • 泛化能力:模型能够从无约束的自然数据中学习到更通用的底层视觉特征,而不仅仅局限于某个特定标注任务。
  • 可扩展性:天然适配大规模数据集与大模型训练。随着数据量和模型参数量增长,性能可以持续提升。

1.2 MAE 的创新点

MAE 的成功主要来源于三项关键技术创新:

  1. 不对称编码器-解码器架构:编码器只处理那些没有被遮挡的可见 patch(计算量降低约 75%),因此非常轻量高效;而解码器则需要处理所有的 patch,专门负责重建被遮挡的图像内容。
  2. 高比例随机掩码:采用 75% 的极端随机遮挡比例,迫使模型必须理解图像中不同区域之间的全局语义关联,而不能简单依赖局部的纹理填充。
  3. 轻量级像素级重建目标:直接预测被遮挡 patch 的原始 RGB 像素值,不需要额外引入预训练的 VAE(变分自编码器)或 Tokenizer 等辅助模块,实现简单且训练稳定。

2. MAE 架构详解

2.1 不对称编码器-解码器设计

编码器基于标准的 Vision Transformer(ViT),但只保留了对未掩码 patch 的处理逻辑。它负责将可见的 patch 压缩成高维特征表示,核心模块包括 patch 嵌入、位置编码和 Transformer 编码层。

以下代码展示了 MAE 编码器的具体实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """
    将224×224图像分割为14×14=196个16×16 patch,并嵌入为768维向量
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, N, D)
        return x

class MAEEncoder(nn.Module):
    """
    MAE轻量级编码器,仅处理未被掩码的patch
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)

        # 类别token(用于下游分类)+ 所有patch的位置嵌入
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # ViT编码器层
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim*mlp_ratio),
                dropout=0.1, activation='gelu', batch_first=True
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask):
        # 1. Patch嵌入
        x = self.patch_embed(x)  # (B, 196, 768)

        # 2. 应用掩码:仅保留未被遮盖的49个patch
        x = x[~mask].reshape(x.shape[0], -1, x.shape[-1])  # (B, 49, 768)

        # 3. 添加类别token和对应位置的嵌入
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)

        pos_keep = self.pos_embed[:, 1:, :][~mask].reshape(x.shape[0], -1, x.shape[-1])
        pos_cls = self.pos_embed[:, :1, :]
        x = x + torch.cat([pos_cls, pos_keep], dim=1)

        # 4. Transformer编码
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

2.2 MAE 解码器设计

解码器虽然整体比编码器更轻量,但它需要处理包括被遮挡 patch 在内的全部 196 个 patch,因此对重建任务的关注更为集中。解码器的关键组件包括:

  • 一个线性映射,将编码器输出的高维特征映射到解码器使用的较低维度。
  • 一个可学习的掩码 token,用来占位所有被遮挡的 patch。
  • 完整的位置嵌入,保证无论 patch 是否被遮挡,模型都能知道它们在原始图像中的位置。
  • 若干 Transformer 层,用于融合可见 patch 的信息并推断被遮挡区域的内容。
  • 一个线性投影层,将每个 patch 的特征映射回原始像素值(例如 16×16×3 的 RGB 值)。
class MAEDecoder(nn.Module):
    """
    MAE重建解码器,处理所有196个patch
    """
    def __init__(self, num_patches=196, patch_size=16, embed_dim=768, decoder_embed_dim=512,
                 decoder_depth=8, decoder_num_heads=16):
        super().__init__()
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)

        # 掩码token(代表被遮盖的patch)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        # 所有patch(含cls)的解码器位置嵌入
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

        # 解码器Transformer层
        self.decoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=decoder_embed_dim, nhead=decoder_num_heads, dim_feedforward=int(decoder_embed_dim*4),
                dropout=0.1, activation='gelu', batch_first=True
            ) for _ in range(decoder_depth)
        ])

        # 重建每个patch的RGB像素值
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)

    def forward(self, x, ids_restore):
        # 1. 维度降维到解码器嵌入维度
        x = self.decoder_embed(x)  # (B, 50, 512)

        # 2. 拼接可见patch和掩码token,并恢复原始196个patch的顺序
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_no_cls = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_no_cls = torch.gather(x_no_cls, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_no_cls], dim=1)

        # 3. 添加完整位置嵌入并解码
        x = x + self.decoder_pos_embed
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # 4. 移除cls token,重建像素
        x = self.decoder_pred(x[:, 1:, :])
        return x

3. 掩码策略与完整模型

3.1 随机高比例掩码实现

MAE 的掩码策略看似简单,但对训练至关重要。具体做法是:对每个样本,随机打乱所有 patch 的顺序,然后选择前面的部分作为保留的可见 patch,剩下的为遮挡 patch。为了能够让解码器正确恢复原始图像块的位置,还需要保存一个“还原索引”,用于将被遮挡的 token 放回原始顺序。

def generate_random_mask(B, N, mask_ratio=0.75):
    """
    生成随机高比例掩码
    """
    len_keep = int(N * (1 - mask_ratio))

    # 生成随机噪声排序
    noise = torch.rand(B, N)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # 创建掩码:True表示被遮盖
    mask = torch.ones([B, N])
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore).bool()
    return mask, ids_restore

3.2 MAE 完整模型与训练流程

完整模型将编码器、解码器以及损失计算封装在一起。训练时,损失仅计算在被遮挡的 patch 上,并且默认采用标准化像素损失(即对每个 patch 的像素做均值方差归一化后再计算 MSE),这可以进一步提升模型的稳定性和最终性能。

class MaskedAutoencoder(nn.Module):
    """
    完整MAE模型
    """
    def __init__(self, img_size=224, patch_size=16, mask_ratio=0.75):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.norm_pix_loss = True

        # 编码器与解码器(使用默认ViT-Base配置)
        self.encoder = MAEEncoder()
        self.decoder = MAEDecoder()

    def patchify(self, imgs):
        """将图像分割为patch"""
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = x.permute(0, 2, 4, 3, 5, 1).flatten(1, 2).flatten(2, 4)
        return x

    def forward_loss(self, imgs, pred, mask):
        """仅计算被掩码patch的标准化像素损失"""
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6)**0.5
        loss = (pred - target)**2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def forward(self, imgs):
        mask, ids_restore = generate_random_mask(imgs.shape[0], self.encoder.patch_embed.num_patches, self.mask_ratio)
        latent = self.encoder(imgs, mask)
        pred = self.decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

4. 预训练与下游应用

4.1 预训练要点

在实际预训练中,推荐使用 AdamW 优化器,并配合学习率预热余弦退火策略。数据增强只需使用简单的随机缩放裁剪水平翻转,无需复杂的数据增广操作,因为 MAE 自身的高比例掩码已经提供了极强的正则化效果。

4.2 微调步骤

预训练完成后,将模型应用到下游任务(如图像分类、目标检测等)一般分为以下步骤:

  1. 提取编码器:丢弃训练时使用的解码器,只保留 MAE 的 ViT 编码器部分。
  2. 添加任务头:例如在 ImageNet 分类任务中,在编码器输出的类别 token 后接一个线性层,映射到 1000 个类别。
  3. 微调策略:可以先用线性探测(Linear Probe)的方式,冻结编码器参数只训练分类头;然后解冻全部参数进行端到端的全量微调,以获得最佳效果。

4.3 使用 timm 库加载预训练模型

借助 timm 库,你可以非常方便地加载预训练好的 MAE 模型并提取图像特征,无需从头编写上述所有代码。

import torch
import timm

# 加载预训练MAE ViT-Base模型
# num_classes=0 表示仅返回特征,不附加分类头
model = timm.create_model('mae_vit_base_patch16_224', pretrained=True, num_classes=0)

# 提取图像特征
model.eval()
with torch.no_grad():
    features = model(torch.randn(2, 3, 224, 224))  # (2, 768)

总结

MAE 通过高比例随机掩码 + 不对称编码器-解码器 + 像素级重建这一简洁组合,成功地将 NLP 领域的掩码建模思想迁移到了计算机视觉中。它显著提升了 Vision Transformer 在下游任务中的性能(例如,ImageNet Top-1 准确率从纯监督 ViT-B 的 82.2% 提升到了 MAE 预训练后 ViT-B 的 83.6%)。该方法实现简单、数据效率高,目前已经成为现代视觉 Transformer 预训练的一种标配范式。

💡 扩展阅读