#Swin Transformer: Shifted Windows and Hierarchical Feature Extraction Explained
Swin Transformer is a hierarchical vision Transformer backbone proposed by Microsoft Research Asia in 2021. Through local window attention with periodic shifts and a hierarchical structure, it addresses the two main weaknesses of the original ViT (Vision Transformer), namely computational efficiency and the lack of multi-scale features, and has become a widely used backbone for image classification, object detection, semantic segmentation, and other vision tasks.
📂 Series stage: Stage 2 - Deep Learning Vision Fundamentals (Vision Transformer supplement)
🔗 Prerequisite reading: Vision Transformer (ViT) Explained · Attention Mechanisms
#1. Core Innovation: Solving ViT's Two Pain Points
The original ViT treats an image as a bag of independent, flattened patches: global self-attention makes the computation grow quadratically with the number of tokens, and the single-resolution architecture lacks the CNN-style multi-scale inductive bias.
#1.1 Improvements over ViT
| Dimension | ViT | Swin Transformer |
|---|---|---|
| Self-attention scope | Global | Local fixed-size windows |
| Computational complexity | Quadratic in the number of tokens | Linear in the number of tokens (window size fixed) |
| Feature hierarchy | Single resolution | 4 stages (analogous to a ResNet: stem + 3 downsamplings) |
| Cross-window information flow | N/A (attention is already global) | Direct, via periodically shifted windows |
| Efficiency at high resolution | Very low | Good |
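The complexity row can be made precise. Following the Swin paper, for an h×w grid of tokens with channel dimension C and a fixed window size M (e.g. M = 7):

$$\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C, \qquad \Omega(\mathrm{W\text{-}MSA}) = 4hwC^2 + 2M^2hwC$$

The projection term 4hwC² is shared by both; only the attention term differs, dropping from quadratic to linear in the token count hw.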
#2. Core Mechanisms in Detail
#2.1 Window Attention (W-MSA)
Global attention is replaced by attention computed independently inside non-overlapping local windows, which reduces the complexity as shown above. A learned relative position bias preserves spatial relationships within each window.
import torch
import torch.nn as nn
from einops import rearrange
class WindowAttention(nn.Module):
"""窗口注意力模块(W-MSA/SW-MSA的核心)"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # (Wh, Ww)
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
        # Learnable relative position bias table: one bias per head for each possible 2-D offset
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1), num_heads)
)
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        # Precompute the relative position index once (reused at every forward pass)
        coords_h, coords_w = torch.arange(window_size[0]), torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # (2, Wh, Ww)
        coords_flat = torch.flatten(coords, 1)  # (2, Nw) with Nw = Wh*Ww
relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :] # (2, Nw, Nw)
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift to a non-negative range
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
        # Merge the two axes into a single 1-D index
relative_coords[:, :, 0] *= 2*window_size[1] - 1
self.register_buffer("relative_position_index", relative_coords.sum(-1)) # (Nw, Nw)
self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: (num_windows*B, Nw, C)
mask: (num_windows, Nw, Nw) or None
Returns:
x: (num_windows*B, Nw, C)
"""
B_, Nw, C = x.shape
qkv = self.qkv(x).reshape(B_, Nw, 3, self.num_heads, C//self.num_heads).permute(2,0,3,1,4)
q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale  # scale queries (avoids an in-place op on a view of qkv)
attn = q @ k.transpose(-2, -1)
        # Add the relative position bias
rel_pos_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
rel_pos_bias = rel_pos_bias.view(Nw, Nw, -1).permute(2,0,1).contiguous()
attn += rel_pos_bias.unsqueeze(0)
        # Apply the attention mask for shifted windows (SW-MSA)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_//nW, nW, self.num_heads, Nw, Nw) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, Nw, Nw)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1,2).reshape(B_, Nw, C)
x = self.proj(x)
x = self.proj_drop(x)
        return x
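A minimal smoke test of the module above (shapes are illustrative; dim=96 with 3 heads and 7×7 windows matches the first stage of Swin-T):
attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
windows = torch.randn(8, 7*7, 96)  # 8 windows of 49 tokens each
out = attn(windows)                # mask=None → plain W-MSA
print(out.shape)                   # torch.Size([8, 49, 96])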
#2.2 Shifted Windows (SW-MSA)
With W-MSA alone there is no information exchange between windows, so the model can never build up a global receptive field. Swin therefore cyclically shifts the window grid in alternating blocks to create cross-window connections, and uses an attention mask to block the invalid attention that the shift introduces between non-adjacent regions.
def window_partition(x, window_size):
"""(B, H, W, C) → (num_windows*B, window_size, window_size, C)"""
B, H, W, C = x.shape
x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
return x.permute(0,1,3,2,4,5).contiguous().view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
"""(num_windows*B, window_size, window_size, C) → (B, H, W, C)"""
B = int(windows.shape[0] / (H*W/window_size/window_size))
x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
return x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)
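A quick sanity check that the two helpers are exact inverses (toy shapes chosen for illustration):
x = torch.randn(2, 8, 8, 4)                           # (B, H, W, C)
wins = window_partition(x, 4)                         # (2*4, 4, 4, 4): four 4×4 windows per image
assert torch.equal(window_reverse(wins, 4, 8, 8), x)  # lossless round trip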
class SwinTransformerBlock(nn.Module):
"""Swin基本块:W-MSA + FFN → SW-MSA + FFN(相邻块交替使用W/SW)"""
def __init__(self, dim, input_res, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., drop_path=0., norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_res = input_res # (H, W)
self.window_size = window_size
self.shift_size = shift_size
if min(input_res) <= window_size:
self.shift_size, self.window_size = 0, min(input_res)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(dim, (self.window_size,)*2, num_heads)
        # Simplification: the official implementation uses stochastic depth (timm's DropPath);
        # plain Dropout on the residual branch keeps this sketch self-contained
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim*mlp_ratio)),
nn.GELU(),
nn.Linear(int(dim*mlp_ratio), dim)
)
        # Precompute the attention mask for SW-MSA
if self.shift_size > 0:
H, W = self.input_res
img_mask = torch.zeros(1, H, W, 1)
            # Partition the shifted image into 9 regions (only windows that straddle
            # region boundaries need masking)
            h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size**2)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
B, L, C = x.shape
H, W = self.input_res
assert L == H*W
shortcut = x
x = self.norm1(x).view(B, H, W, C)
        # Cyclic shift (SW-MSA only)
if self.shift_size > 0:
x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1,2))
        # Partition into windows → attention → merge windows back
x = window_partition(x, self.window_size).view(-1, self.window_size**2, C)
x = self.attn(x, self.attn_mask)
x = window_reverse(x.view(-1, self.window_size, self.window_size, C), self.window_size, H, W)
        # Reverse the cyclic shift (SW-MSA only)
if self.shift_size > 0:
x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1,2))
x = x.view(B, H*W, C)
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
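A minimal sketch of how consecutive blocks pair up (illustrative sizes; a 56×56 token grid with dim=96 matches stage 1 of Swin-T):
w_blk  = SwinTransformerBlock(dim=96, input_res=(56, 56), num_heads=3, window_size=7, shift_size=0)
sw_blk = SwinTransformerBlock(dim=96, input_res=(56, 56), num_heads=3, window_size=7, shift_size=3)
tokens = torch.randn(1, 56*56, 96)
out = sw_blk(w_blk(tokens))  # W-MSA block, then SW-MSA block, as in the paper
print(out.shape)             # torch.Size([1, 3136, 96])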
#2.3 Patch Merging (the Core of the Hierarchy)
Analogous to pooling plus channel fusion in a CNN: the spatial resolution is halved and the channel count doubled, building a multi-scale feature pyramid:
- take each 2×2 group of neighboring patches
- concatenate along the channel axis (4C)
- linearly project down to 2C
class PatchMerging(nn.Module):
def __init__(self, input_res, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_res = input_res
self.norm = norm_layer(4*dim)
self.reduction = nn.Linear(4*dim, 2*dim, bias=False)
def forward(self, x):
H, W = self.input_res
B, L, C = x.shape
assert L == H*W and H%2==0 and W%2==0
x = x.view(B, H, W, C)
        # Gather the 2×2 neighboring patches
        x0 = x[:, 0::2, 0::2, :]  # top-left
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
x = torch.cat([x0,x1,x2,x3], -1).view(B, -1, 4*C)
        return self.reduction(self.norm(x))
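A quick check of the shape contract (illustrative sizes):
merge = PatchMerging(input_res=(56, 56), dim=96)
tokens = torch.randn(1, 56*56, 96)
print(merge(tokens).shape)  # torch.Size([1, 784, 192]): resolution halved, channels doubled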
#3. Quick Start with Pretrained Models
The two most common ways to load a pretrained Swin model are the timm library (concise and efficient) and Hugging Face Transformers (more general).
#3.1 Using timm
import torch
import timm
# List all available Swin models
# print(timm.list_models('swin*', pretrained=True))
# Load a pretrained classification model (Base variant)
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=1000)
model.eval()
# Inference example
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(input_tensor)  # (1, 1000) ImageNet classification logits
print(f"Classification output shape: {output.shape}")
# Extract intermediate multi-scale features (for detection/segmentation)
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, features_only=True)
with torch.no_grad():
    features = model(input_tensor)  # 4 feature maps at different scales
for i, feat in enumerate(features):
print(f"Stage {i+1} 特征形状: {feat.shape}")#3.2 使用Hugging Face
#3.2 Using Hugging Face
from transformers import AutoImageProcessor, SwinForImageClassification
from PIL import Image
import requests
# Load the processor and the model
processor = AutoImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model = SwinForImageClassification.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model.eval()
# Preprocess a real image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
# Run inference and take the argmax prediction
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print(f"预测类别: {model.config.id2label[predicted_class_idx]}")#4. 总结与学习建议
The three core ideas of Swin Transformer:
- ✅ Local window attention: makes the computation linear in the number of tokens
- ✅ Shifted windows + masking: enable cross-window, and eventually global, information flow
- ✅ Patch Merging: builds a multi-scale feature pyramid
Study suggestions:
- First get a pretrained model running via timm / Hugging Face
- Focus on understanding SW-MSA's mask generation and cyclic shift
- Then try downstream tasks, e.g. combining the backbone with UperNet or DINOv2
💡 Recommended reading: Liu et al., "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows", ICCV 2021 (arXiv:2103.14030).