"""
ViT:全局注意力,复杂度 O(n²)
Swin:局部注意力,复杂度 O(n)
窗口大小:7×7
相邻层错位:实现全局连接
"""
import torch
import torch.nn as nn
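
# A minimal sketch (not from the original source) of the attention-cost
# formulas given in the Swin paper: global MSA is quadratic in the token
# count h*w, while window attention (W-MSA) is linear in h*w for a fixed
# window size M. This makes the O(n^2) vs. O(n) claim above concrete.
def msa_flops(h, w, c):
    # Global multi-head self-attention: 4*h*w*C^2 + 2*(h*w)^2*C
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c

def wmsa_flops(h, w, c, m=7):
    # Window attention: 4*h*w*C^2 + 2*M^2*h*w*C, with M the window size
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c

# e.g. a 56x56 feature map with C=96:
# print(msa_flops(56, 56, 96), wmsa_flops(56, 56, 96))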
class WindowAttention(nn.Module):
    """Multi-head self-attention within a single window.

    Note: the learnable relative position bias of the original Swin
    implementation is omitted in this simplified version.
    """
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.softmax = nn.Softmax(dim=-1)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (B * num_windows, N, C), where N = window_size * window_size
        B, N, C = x.shape
        # Project to Q, K, V and split heads: 3 x (B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention, restricted to tokens in the same window
        attn = self.softmax((q @ k.transpose(-2, -1)) * self.scale)
        # Merge heads back and project to (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
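
# A minimal sketch (assumed helper, not part of the original snippet) showing
# how a feature map is cut into non-overlapping windows, and how the cyclic
# shift used in alternating Swin layers moves tokens across window borders.
# The attention mask that blocks wrapped-around tokens in real shifted-window
# attention is omitted here for brevity.
def window_partition(x, window_size):
    # (B, H, W, C) -> (B * num_windows, window_size * window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)

if __name__ == "__main__":
    B, H, W, C = 2, 56, 56, 96
    x = torch.randn(B, H, W, C)
    attn = WindowAttention(dim=C, window_size=7, num_heads=4)
    # Regular layer: attention inside aligned 7x7 windows
    out = attn(window_partition(x, 7))
    print(out.shape)  # (2 * 64 windows, 49, 96)
    # Shifted layer: roll the map by window_size // 2 before partitioning,
    # so each new window mixes tokens from four previously separate windows
    shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
    out_shifted = attn(window_partition(shifted, 7))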