import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ----------------------------
# 1. Core building blocks
# ----------------------------
class PatchEmbedding(nn.Module):
"""
将图像分割成固定大小的patch并做线性嵌入
等价于先用16x16卷积步长16,再展平转置
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2  # 224/16=14 → 14²=196 patches
        # A conv layer implements the patch embedding efficiently (no manual patch-splitting loop)
self.conv_proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (batch_size, 3, 224, 224)
x = self.conv_proj(x) # → (batch_size, 768, 14, 14)
x = x.flatten(2) # → (batch_size, 768, 196)
x = x.transpose(1, 2) # → (batch_size, 196, 768)
return x
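# --- Sanity-check sketch (not part of the model): the docstring above claims the
# strided conv is equivalent to "split into patches, flatten, linearly project".
# A minimal demonstration via F.unfold; `_check_patch_equivalence` is a
# hypothetical helper name introduced here for illustration.
def _check_patch_equivalence():
    pe = PatchEmbedding(img_size=32, patch_size=16, in_channels=3, embed_dim=8)
    x = torch.randn(2, 3, 32, 32)
    # Manual route: extract 16x16 patches → (batch, n_patches, 3*16*16)
    patches = F.unfold(x, kernel_size=16, stride=16).transpose(1, 2)
    # Reuse the conv kernel as a plain linear weight: (embed_dim, 3*16*16)
    w = pe.conv_proj.weight.reshape(8, -1)
    manual = patches @ w.t() + pe.conv_proj.bias  # (2, 4, 8)
    assert torch.allclose(manual, pe(x), atol=1e-5)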
class MultiHeadSelfAttention(nn.Module):
"""
标准多头自注意力(MHSA)
"""
def __init__(self, embed_dim=768, n_heads=12, dropout=0.1):
super().__init__()
self.n_heads = n_heads
self.head_dim = embed_dim // n_heads
        assert self.head_dim * n_heads == embed_dim, "embed_dim must be divisible by n_heads"
        # Produce Q, K, V with a single linear layer (more efficient than three separate ones)
self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
self.attn_dropout = nn.Dropout(dropout)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, embed_dim = x.size()
        # 1. Compute Q, K, V and split them into heads
qkv = self.qkv_proj(x) # → (batch_size, seq_len, 3*embed_dim)
qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # → (3, batch_size, n_heads, seq_len, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
        # 2. Attention scores + scaling + dropout
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn_probs = F.softmax(attn_scores, dim=-1)
attn_probs = self.attn_dropout(attn_probs)
        # 3. Weighted sum over V + concat heads + output projection
context = torch.matmul(attn_probs, v) # → (batch_size, n_heads, seq_len, head_dim)
context = context.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
return self.out_proj(context)
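# --- Hedged alternative (an assumption, requires PyTorch >= 2.0): steps 2-3 above
# can be fused into F.scaled_dot_product_attention, which applies the same
# 1/sqrt(head_dim) scaling, softmax, and dropout internally, often via a faster
# fused kernel. A minimal sketch; `_check_sdpa_equivalence` is a hypothetical
# helper name introduced here for illustration.
def _check_sdpa_equivalence():
    mhsa = MultiHeadSelfAttention(embed_dim=32, n_heads=4, dropout=0.0).eval()
    x = torch.randn(2, 5, 32)
    qkv = mhsa.qkv_proj(x).reshape(2, 5, 3, 4, 8).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    fused = F.scaled_dot_product_attention(q, k, v)  # same math as steps 2-3
    fused = fused.transpose(1, 2).reshape(2, 5, 32)
    assert torch.allclose(mhsa.out_proj(fused), mhsa(x), atol=1e-5)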
class MLPBlock(nn.Module):
"""
ViT中的MLP块:GELU激活 + 隐藏层扩张4倍
"""
def __init__(self, embed_dim=768, mlp_dim=3072, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(embed_dim, mlp_dim)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(mlp_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.gelu(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class TransformerEncoderLayer(nn.Module):
"""
标准Transformer编码器层:Pre-LN架构(原论文采用)
Pre-LN:先做LayerNorm,再过残差,更稳定
"""
def __init__(self, embed_dim=768, n_heads=12, mlp_dim=3072, dropout=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(embed_dim)
self.mhsa = MultiHeadSelfAttention(embed_dim, n_heads, dropout)
self.ln2 = nn.LayerNorm(embed_dim)
self.mlp = MLPBlock(embed_dim, mlp_dim, dropout)
def forward(self, x):
        # Pre-LN + residual
x = x + self.mhsa(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
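# --- For contrast (illustrative note, not used by this model): the original
# Transformer ("Attention Is All You Need") placed LayerNorm *after* the residual
# addition (Post-LN). In a Post-LN forward, the two residual lines above become:
#
#   x = self.ln1(x + self.mhsa(x))
#   x = self.ln2(x + self.mlp(x))
#
# Pre-LN keeps an identity path through every residual connection, which is why
# it tends to train more stably as depth grows.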
# ----------------------------
# 2. Full ViT model
# ----------------------------
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
embed_dim=768, depth=12, n_heads=12, mlp_dim=3072, dropout=0.1):
super().__init__()
        # 1. Patch embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
n_patches = self.patch_embed.n_patches
        # 2. Class token: a learnable global classification vector
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # 3. Learnable position embeddings: no hard-coded sinusoids, the model learns them itself
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
self.pos_dropout = nn.Dropout(dropout)
        # 4. Stack of Transformer encoder layers
self.encoder = nn.Sequential(*[
TransformerEncoderLayer(embed_dim, n_heads, mlp_dim, dropout)
for _ in range(depth)
])
        # 5. Classification head
self.ln_head = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
        # 6. Weight initialization (the original paper uses a truncated normal distribution)
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
batch_size = x.size(0)
        # Step 1: patch embedding
x = self.patch_embed(x) # → (batch_size, 196, 768)
        # Step 2: prepend the class token (one per sample)
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # → (batch_size, 1, 768)
x = torch.cat([cls_tokens, x], dim=1) # → (batch_size, 197, 768)
        # Step 3: add position embeddings + dropout
x = x + self.pos_embed
x = self.pos_dropout(x)
        # Step 4: run the encoder stack
x = self.encoder(x)
        # Step 5: classify from the class token output
        x = self.ln_head(x[:, 0])  # take only the first position (the class token)
x = self.head(x)
return x
# ----------------------------
# 3. Quick model test
# ----------------------------
if __name__ == "__main__":
    # Instantiate ViT-Base (corresponds to ViT-B/16 in the original paper)
vit_base = VisionTransformer(
img_size=224, patch_size=16,
embed_dim=768, depth=12, n_heads=12, mlp_dim=3072,
num_classes=1000
)
    # Count parameters
    total_params = sum(p.numel() for p in vit_base.parameters())
    print(f"ViT-B/16 total parameters: {total_params / 1e6:.1f}M")  # ~86M
    # Test the forward pass
    dummy_img = torch.randn(1, 3, 224, 224)
    output = vit_base(dummy_img)
    print(f"Input shape: {dummy_img.shape}")
    print(f"Output shape: {output.shape}")  # should be (1, 1000)