Swin Transformer:滑动窗口机制与层级特征提取详解

Swin Transformer是微软亚洲研究院2021年提出的革命性视觉Transformer基准架构,通过局部滑动窗口和层级结构,完美解决了原始ViT(Vision Transformer)在计算效率和多尺度特征提取上的致命缺陷,已成为图像分类、目标检测、语义分割等全视觉任务的首选骨干网络。

📂 所属阶段:第二阶段 — 深度学习视觉基础(视觉Transformer 补充)
🔗 前置阅读:Vision Transformer (ViT) 详解 · 注意力机制


1. 核心创新:解决ViT的两大痛点

原始ViT把图像视为“一堆独立的扁平 patch”,全局注意力导致计算量随 token 数量平方级增长,且单一分辨率结构缺少CNN式的多尺度归纳偏置。

1.1 对比ViT的改进

| 维度 | ViT | Swin Transformer |
| --- | --- | --- |
| 自注意力范围 | 全局 | 局部固定窗口 |
| 计算复杂度 | 与 token 数量的平方成正比 | 与 token 数量成线性关系(窗口大小固定) |
| 特征层级 | 单一分辨率 | 4层级(类似ResNet的Stem+3次下采样) |
| 跨窗口信息流动 | 无(仅全局间接) | 周期性移位窗口(Shifted Window)直接实现 |
| 高分辨率处理效率 | 极低 | 良好 |

2. 核心机制详解

2.1 窗口注意力(W-MSA)

全局注意力改为在不重叠的局部窗口内计算,计算复杂度由与 token 数量的平方成正比降为线性关系。同时引入相对位置偏置(relative position bias)以保留窗口内的局部空间关系。

import torch
import torch.nn as nn
from einops import rearrange

class WindowAttention(nn.Module):
    """Multi-head self-attention computed inside one local window.

    The same module serves the regular (W-MSA) and shifted (SW-MSA) variants;
    the shifted variant only differs by the additive ``mask`` given to
    :meth:`forward`. A learnable relative position bias, looked up by the
    offset between two tokens of the window, is added to the attention logits.

    Args:
        dim: Channel count of the input tokens.
        window_size: ``(Wh, Ww)`` height and width of the window.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        win_h, win_w = window_size[0], window_size[1]
        # Offsets along one axis range over [-(w - 1), w - 1]: 2w - 1 distinct values.
        span_h, span_w = 2 * win_h - 1, 2 * win_w - 1

        # One learnable bias per (relative offset, head).
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(span_h * span_w, num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

        # Precompute, for every token pair of the window, its row in the bias table.
        grid = torch.meshgrid([torch.arange(win_h), torch.arange(win_w)], indexing='ij')
        flat = torch.stack(grid).flatten(1)  # (2, Nw) with Nw = Wh * Ww
        offsets = flat.unsqueeze(2) - flat.unsqueeze(1)  # (2, Nw, Nw)
        row_offset = offsets[0] + (win_h - 1)  # shifted to start at 0
        col_offset = offsets[1] + (win_w - 1)
        # Row-major flattening of the (row_offset, col_offset) pair.
        self.register_buffer("relative_position_index", row_offset * span_w + col_offset)  # (Nw, Nw)

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Apply windowed attention.

        Args:
            x: (num_windows*B, Nw, C)
            mask: (num_windows, Nw, Nw) additive mask, or None
        Returns:
            Tensor of shape (num_windows*B, Nw, C)
        """
        total_windows, tokens, channels = x.shape
        heads = self.num_heads

        fused = self.qkv(x).reshape(total_windows, tokens, 3, heads, channels // heads)
        # Each of q, k, v: (num_windows*B, heads, Nw, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4)

        logits = (q * self.scale) @ k.transpose(-2, -1)

        # Look up the bias of every token pair and broadcast it over the batch.
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()  # (heads, Nw, Nw)
        logits += bias.unsqueeze(0)

        # Shifted windows: add the per-window mask, broadcast over batch and heads.
        if mask is not None:
            window_count = mask.shape[0]
            logits = logits.view(total_windows // window_count, window_count, heads, tokens, tokens)
            logits = logits + mask[None, :, None]
            logits = logits.view(-1, heads, tokens, tokens)

        weights = self.attn_drop(self.softmax(logits))

        merged = (weights @ v).transpose(1, 2).reshape(total_windows, tokens, channels)
        return self.proj_drop(self.proj(merged))

2.2 移位窗口(SW-MSA)

仅用W-MSA的话,窗口间完全无信息交换,无法模拟全局感受野。Swin通过周期性窗口移位实现跨窗口连接,且用掩码避免移位后的无效注意力。

def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: Tensor of shape (B, H, W, C); H and W must be multiples of
            ``window_size`` for the reshape to succeed.
        window_size: Side length of each window.

    Returns:
        Tensor of shape (num_windows*B, window_size, window_size, C), windows
        ordered row-major within each image.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    # Split both spatial axes into (window index, offset inside the window).
    tiles = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes together so each window is one contiguous tile.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, channels)

def window_reverse(windows, window_size, H, W):
    """Stitch windows back into a feature map (inverse of ``window_partition``).

    Args:
        windows: Tensor of shape (num_windows*B, window_size, window_size, C),
            windows ordered row-major within each image.
        window_size: Side length of each window.
        H: Height of the reconstructed feature map.
        W: Width of the reconstructed feature map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    rows, cols = H // window_size, W // window_size
    # Pure integer arithmetic: the batch size is an exact count, so there is
    # no reason to go through float division and truncate the result.
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave (window row, in-window row) and (window col, in-window col).
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

class _DropPath(nn.Module):
    """Stochastic depth: drop the whole residual branch of randomly chosen samples."""

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining dimension.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.:
            keep_mask.div_(keep_prob)  # rescale so the expected value is unchanged
        return x * keep_mask


class SwinTransformerBlock(nn.Module):
    """One Swin block: window attention followed by an MLP, each with a residual.

    With ``shift_size == 0`` the block performs regular window attention
    (W-MSA); with ``shift_size > 0`` the feature map is cyclically shifted
    before being split into windows (SW-MSA) and a precomputed mask stops
    tokens that were not neighbours before the shift from attending to each
    other. Consecutive blocks are meant to alternate between the two.

    Args:
        dim: Channel count of the tokens.
        input_res: ``(H, W)`` resolution of the token grid.
        num_heads: Number of attention heads.
        window_size: Side length of the attention window.
        shift_size: Cyclic shift in tokens; 0 disables shifting.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop_path: Stochastic-depth probability applied to both residual branches.
        norm_layer: Normalisation layer constructor, called as ``norm_layer(dim)``.

    Raises:
        ValueError: If the effective shift is not in ``[0, window_size)`` or the
            resolution is not a multiple of the effective window size.
    """

    def __init__(self, dim, input_res, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., drop_path=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_res = input_res  # (H, W)
        self.window_size = window_size
        self.shift_size = shift_size
        # A grid no larger than the window is covered by a single window:
        # shrink the window to fit and skip shifting altogether.
        if min(input_res) <= window_size:
            self.shift_size, self.window_size = 0, min(input_res)

        if not 0 <= self.shift_size < self.window_size:
            raise ValueError(
                f"shift_size must be in [0, {self.window_size}), got {self.shift_size}"
            )
        H, W = self.input_res
        if H % self.window_size or W % self.window_size:
            raise ValueError(
                f"input_res {tuple(self.input_res)} is not a multiple of "
                f"window size {self.window_size}"
            )

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim, (self.window_size,)*2, num_heads)
        # Stochastic depth drops the branch per sample; element-wise nn.Dropout
        # would instead zero individual activations.
        self.drop_path = _DropPath(drop_path) if drop_path>0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim*mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

        self.register_buffer("attn_mask", self._build_attn_mask())

    def _build_attn_mask(self):
        """Return the (num_windows, Nw, Nw) additive SW-MSA mask, or None without shift."""
        if self.shift_size == 0:
            return None

        # Always use the effective (possibly clamped) sizes stored on self.
        ws, shift = self.window_size, self.shift_size
        H, W = self.input_res
        img_mask = torch.zeros(1, H, W, 1)
        # Label the 3x3 grid of regions of the shifted map; only tokens carrying
        # the same label may attend to each other.
        bands = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
        region = 0
        for h in bands:
            for w in bands:
                img_mask[:, h, w, :] = region
                region += 1

        mask_windows = window_partition(img_mask, ws).view(-1, ws**2)
        label_diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # Different labels -> large negative logit (suppressed by softmax); same -> 0.
        return label_diff.masked_fill(label_diff != 0, float(-100.0)).masked_fill(label_diff == 0, 0.0)

    def forward(self, x):
        """Run the block on tokens of shape (B, H*W, C); the output has the same shape."""
        B, L, C = x.shape
        H, W = self.input_res
        assert L == H*W, f"expected {H*W} tokens for resolution {(H, W)}, got {L}"

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Cyclic shift (SW-MSA only)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1,2))

        # Partition into windows -> attention -> merge windows back
        x = window_partition(x, self.window_size).view(-1, self.window_size**2, C)
        x = self.attn(x, self.attn_mask)
        x = window_reverse(x.view(-1, self.window_size, self.window_size, C), self.window_size, H, W)

        # Undo the cyclic shift (SW-MSA only)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1,2))

        x = x.view(B, H*W, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

2.3 Patch Merging(层级结构核心)

类似CNN的池化+通道融合,实现分辨率减半、通道数加倍,构建多尺度特征金字塔:

  1. 取相邻2×2的patch
  2. 拼接通道(4C)
  3. 线性投影降维到2C
class PatchMerging(nn.Module):
    """Downsampling layer that merges every 2x2 group of neighbouring tokens.

    The four tokens of a group are concatenated along the channel axis (4C),
    normalised, and linearly projected to 2C, which halves the spatial
    resolution and doubles the channel count.

    Args:
        input_res: ``(H, W)`` resolution of the incoming token grid.
        dim: Channel count C of the incoming tokens.
        norm_layer: Normalisation layer constructor, called as ``norm_layer(4*dim)``.
    """

    def __init__(self, input_res, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_res = input_res
        merged_dim = 4 * dim
        self.norm = norm_layer(merged_dim)
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)

    def forward(self, x):
        """Map tokens of shape (B, H*W, C) to (B, H/2 * W/2, 2C)."""
        H, W = self.input_res
        B, L, C = x.shape
        assert L == H*W and H%2==0 and W%2==0

        grid = x.view(B, H, W, C)
        # The four corners of every 2x2 cell, in the order
        # top-left, bottom-left, top-right, bottom-right.
        corners = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(corners, dim=-1).view(B, -1, 4 * C)
        return self.reduction(self.norm(merged))

3. 预训练模型快速上手

最常用的Swin预训练模型加载方式有两种:timm库(简洁高效)和Hugging Face Transformers(更通用)。

3.1 使用timm库

import torch
import timm

# List every Swin variant that ships pretrained weights:
# print(timm.list_models('swin*', pretrained=True))

# Load the pretrained classifier (Base variant).
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=1000)
model.eval()

# Inference example on a random input.
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(input_tensor)  # (1, 1000) ImageNet classification logits
    print(f"分类输出形状: {output.shape}")

# Rebuild the backbone as a multi-scale feature extractor (for detection / segmentation).
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, features_only=True)
# A freshly created module is in training mode; switch to eval so that any
# dropout / stochastic-depth layers are disabled and the features are deterministic.
model.eval()
with torch.no_grad():
    features = model(input_tensor)  # one feature map per stage
    for i, feat in enumerate(features):
        print(f"Stage {i+1} 特征形状: {feat.shape}")

3.2 使用Hugging Face

import requests
import torch
from PIL import Image
# `transformers` has no Swin-specific image processor class; AutoImageProcessor
# resolves the right processor from the checkpoint's preprocessor config.
from transformers import AutoImageProcessor, SwinForImageClassification

# Load the image processor and the classification model.
processor = AutoImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model = SwinForImageClassification.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model.eval()

# Fetch a real image; fail fast on a stalled connection or an HTTP error
# instead of handing an error page to PIL.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Force three channels so grayscale / RGBA inputs are handled as well.
image = Image.open(response.raw).convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Run inference and report the top-1 class.
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print(f"预测类别: {model.config.id2label[predicted_class_idx]}")

4. 总结与学习建议

Swin Transformer的三大核心:

  1. 局部滑动窗口:使计算量与 token 数量成线性关系
  2. 移位窗口+掩码:实现跨窗口全局信息流动
  3. Patch Merging:构建多尺度特征金字塔

学习建议:

  • 先通过timm/Hugging Face跑通预训练模型
  • 重点理解SW-MSA的掩码生成和周期性移位
  • 可以结合UperNet/DINOv2尝试下游任务实践

💡 推荐阅读