Swin Transformer:滑动窗口、层级式特征提取

📂 所属阶段:第四阶段 — 视觉新范式(Transformer 篇)
🔗 相关章节:Vision Transformer (ViT) 详解 · MAE (Masked Autoencoders)


1. Swin Transformer 核心创新

Swin = Shifted Windows

改进 ViT 的两个问题:
1. 计算复杂度高(O(n²))
2. 缺少层级结构

解决方案:
- 滑动窗口:局部注意力
- 层级结构:多尺度特征

2. 滑动窗口机制

"""
ViT:全局注意力,复杂度 O(n²)
Swin:局部注意力,复杂度 O(n)

窗口大小:7×7
相邻层错位:实现全局连接
"""

import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn = nn.Softmax(dim=-1)
        self.proj = nn.Linear(dim, dim)
    
    def forward(self, x):
        # x: (B*num_windows, N, C)
        B, N, C = x.shape
        
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * (C // self.num_heads) ** -0.5
        attn = self.attn(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

3. 使用预训练 Swin

import torch
from timm.models import swin_base_patch4_window7_224

# 加载预训练 Swin Transformer
model = swin_base_patch4_window7_224(pretrained=True)

# 推理
x = torch.randn(1, 3, 224, 224)
output = model(x)
print(output.shape)  # (1, 1000)

4. Swin vs ViT

特性ViTSwin
注意力范围全局局部
计算复杂度O(n²)O(n)
层级结构
速度
准确率更高

5. 小结

Swin Transformer 优势:

1. 高效:局部注意力减少计算
2. 层级:多尺度特征提取
3. 灵活:可用于检测、分割等任务

2026 年应用:
- 分类:Swin 优于 ViT
- 检测:Swin 是主流
- 分割:Swin 效果最好

💡 记住:Swin Transformer 是 2026 年最实用的视觉模型。它兼具 CNN 的层级性和 Transformer 的全局性。


🔗 扩展阅读