Semantic Segmentation: the U-Net Architecture and Pixel-Level Understanding

📂 Stage: Stage 3 — Core Vision Tasks (Advanced)
🔗 Related chapters: Hands-on with the YOLO Family · Keypoint Detection (Keypoints)


1. Semantic Segmentation Basics

Semantic segmentation = per-pixel classification.

Output: a class label for every pixel of the input image.

Applications:
- Medical imaging: tumor segmentation
- Autonomous driving: road segmentation
- Remote sensing: land-use classification
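The idea above can be sketched in a few lines: the network produces one score per class for every pixel, and the predicted mask is the per-pixel argmax over the class dimension (the class count and image size here are made up purely for illustration):

```python
import torch

# Toy illustration of "segmentation = per-pixel classification":
# one score per class for EVERY pixel; the mask is the per-pixel argmax.
num_classes = 3                              # hypothetical label set
logits = torch.randn(1, num_classes, 4, 4)   # (batch, classes, H, W)

mask = logits.argmax(dim=1)                  # (batch, H, W): one class id per pixel
print(mask.shape)                            # torch.Size([1, 4, 4])
```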

2. The U-Net Architecture

import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        
        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        # Bottleneck
        self.bottleneck = self.conv_block(128, 256)
        
        # Decoder
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self.conv_block(128, 64)
        
        # Output: 1x1 conv maps features to per-pixel class logits
        self.final = nn.Conv2d(64, out_channels, 1)
    
    def conv_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):  # expects H and W divisible by 4 (two 2x poolings)
        # Encoder path (keep pre-pool features for the skip connections)
        enc1 = self.enc1(x)
        x = self.pool1(enc1)
        enc2 = self.enc2(x)
        x = self.pool2(enc2)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder path: upsample, concatenate skip features, convolve
        x = self.upconv2(x)
        x = torch.cat([x, enc2], dim=1)
        x = self.dec2(x)
        
        x = self.upconv1(x)
        x = torch.cat([x, enc1], dim=1)
        x = self.dec1(x)
        
        x = self.final(x)
        return x
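A quick sanity check of the shape bookkeeping behind the skip connections (a standalone sketch with made-up feature sizes): pooling halves the spatial resolution, the transposed convolution doubles it back, and the skip concatenation stacks encoder and decoder features along the channel dimension — which is why each `dec` block takes twice the channels of its `upconv`:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)                        # pretend enc1 features
pooled = nn.MaxPool2d(2, 2)(x)                        # -> (1, 64, 16, 16): half resolution
up = nn.ConvTranspose2d(64, 64, 2, stride=2)(pooled)  # -> (1, 64, 32, 32): doubled back

skip = torch.cat([up, x], dim=1)                      # skip connection: concat on channels
print(skip.shape)                                     # torch.Size([1, 128, 32, 32])
```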

3. Training a Segmentation Model

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Two-class segmentation: CrossEntropyLoss expects raw logits of shape
# (N, C, H, W) and integer masks of shape (N, H, W) with dtype long.
model = UNet(in_channels=3, out_channels=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# train_loader is assumed to be a DataLoader yielding (images, masks) batches.
for epoch in range(10):
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Note: this reports the loss of the last batch only, not an epoch average.
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
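At inference time, the predicted mask comes from a per-pixel argmax over the logits; here is a sketch with random stand-ins for `outputs` and `masks` (real values would come from the trained model and the dataset), including a simple pixel-accuracy score:

```python
import torch

outputs = torch.randn(4, 2, 64, 64)          # (N, num_classes, H, W) logits (stand-in)
masks = torch.randint(0, 2, (4, 64, 64))     # (N, H, W) ground-truth ids (stand-in)

preds = outputs.argmax(dim=1)                # per-pixel class prediction
pixel_acc = (preds == masks).float().mean()  # fraction of correctly labeled pixels
print(f"pixel accuracy: {pixel_acc:.3f}")
```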

4. Summary

The three building blocks of semantic segmentation:

1. Encoder: extracts features
2. Bottleneck: the deepest, most abstract layer
3. Decoder: restores spatial resolution

Key traits of U-Net:
- Skip connections: preserve fine detail
- Symmetric encoder-decoder structure
- Well suited to medical imaging
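In the medical-imaging setting, segmentation quality is usually reported with an overlap metric such as the Dice coefficient; a minimal sketch for binary masks (the `eps` smoothing term is an assumption added to guard against empty masks):

```python
import torch

def dice_coeff(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> float:
    """Dice = 2*|pred ∩ target| / (|pred| + |target|) for binary masks."""
    pred = pred.float().flatten()
    target = target.float().flatten()
    inter = (pred * target).sum()
    return ((2 * inter + eps) / (pred.sum() + target.sum() + eps)).item()

a = torch.tensor([[1, 1], [0, 0]])
b = torch.tensor([[1, 0], [0, 0]])
print(round(dice_coeff(a, b), 3))   # 2*1 / (2+1) ≈ 0.667
```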

Recommendations for 2026:
- Fast: SegNet
- Accurate: DeepLab
- Real-time: BiSeNet

💡 Remember: U-Net's skip connections are the key. They let the model see global context while still preserving local detail.


🔗 Further Reading