Semantic Segmentation: Pixel-Level Image Understanding and a Detailed Look at U-Net

Introduction

Semantic segmentation is one of the core tasks in computer vision: instead of merely recognizing what is in an image, it assigns a semantic label to every pixel, precisely delineating object contours and spatial layout. From organ segmentation in medical imaging to road perception in autonomous driving, this technique is an indispensable foundation.

1. Semantic Segmentation Fundamentals

1.1 Task Definition and Comparison with Related Vision Tasks

Semantic segmentation takes as input an image of height H, width W, and C channels, and outputs a segmentation map of height H, width W, and depth N (where N is the number of predefined classes). The network assigns each pixel location (i, j) a class label drawn from the set {1, 2, ..., N}.
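
In code, this amounts to an argmax over the class dimension of the network's per-pixel scores. A minimal sketch (random logits stand in for a real network's output):

import torch

num_classes = 21
logits = torch.randn(1, num_classes, 480, 640)  # [batch, classes, H, W] per-pixel scores
label_map = logits.argmax(dim=1)                # [batch, H, W], 0-based class indices
print(label_map.shape)                          # torch.Size([1, 480, 640])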

Its key differences from other vision tasks:

Task type | Output form | Core goal
--- | --- | ---
Image classification | a single class label | recognize the content of the whole image
Object detection | a list of [class, bounding box] | locate and identify objects
Semantic segmentation | a [pixels × classes] map | per-pixel classification (instances of the same class are not distinguished)
Instance segmentation | a [pixels × (class, instance ID)] map | per-pixel classification + distinguishing instances of the same class
Panoptic segmentation | as above (with a "thing/stuff" split) | unified pixel-level scene understanding

1.2 Key Application Scenarios

Semantic segmentation is deployed across a wide range of domains:

  • Medical imaging: organ/tumor segmentation, pathology slide analysis
  • Autonomous driving: road/lane/obstacle segmentation
  • Remote sensing: land-use classification, urban planning, environmental monitoring
  • Smart agriculture: crop/pest monitoring, yield estimation
  • Robotics: environment understanding, grasp localization
  • Fashion/entertainment: clothing segmentation, virtual try-on, film post-production matting

2. Classic Semantic Segmentation Architectures

2.1 FCN: The Fully Convolutional Pioneer

FCN (Fully Convolutional Networks, 2015) is a milestone in semantic segmentation: it was the first architecture to achieve end-to-end pixel-level prediction.

Core contributions

  1. Fully convolutional design: the fully connected layers of the classification network are replaced with convolutions, allowing inputs of arbitrary size;
  2. Learned upsampling: transposed convolutions progressively restore spatial resolution (see the sizing sketch below);
  3. Skip connections: low-level detail features from the encoder are fused with high-level semantic features from the decoder, countering the loss of detail after upsampling.

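A quick sanity check on transposed-convolution sizing, which the implementation below relies on: the output size is (in − 1) × stride − 2 × padding + kernel, so kernel 4, stride 2, padding 1 gives exactly 2× upsampling. A minimal sketch:

import torch
import torch.nn as nn

# out = (in - 1) * stride - 2 * padding + kernel
# kernel=4, stride=2, padding=1  ->  out = 2 * in (exact doubling)
up = nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 8, 16, 16)
print(up(x).shape)  # torch.Size([1, 8, 32, 32])
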
PyTorch implementation (FCN-8s)

import torch
import torch.nn as nn
import torch.nn.functional as F

class FCN8s(nn.Module):
    def __init__(self, num_classes=21):
        super().__init__()
        # Use a pretrained VGG16 as the encoder
        vgg16 = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
        self.features = vgg16.features
        
        # Replace the classifier head with convolutions
        self.fc_conv = nn.Sequential(
            nn.Conv2d(512, 4096, 7, padding=3),  # padding=3 preserves the 1/32 feature size
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(4096, 4096, 1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(4096, num_classes, 1)
        )
        
        # 1x1 convolutions for the skip connections
        self.score_pool4 = nn.Conv2d(512, num_classes, 1)
        self.score_pool3 = nn.Conv2d(256, num_classes, 1)
        
        # Upsampling layers (kernel 4, stride 2, padding 1: exact 2x upsampling)
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1, bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1, bias=False)

    def forward(self, x):
        input_size = x.shape[2:]
        # Encoder feature extraction
        pool3 = self.features[:17](x)        # pool3: 1/8 resolution
        pool4 = self.features[17:24](pool3)  # pool4: 1/16 resolution
        pool5 = self.features[24:](pool4)    # pool5: 1/32 resolution
        
        # Upsample high-level features + skip connections
        score_fc = self.fc_conv(pool5)
        upscore2 = self.upscore2(score_fc)
        score_pool4 = self.score_pool4(pool4)
        fuse4 = upscore2 + score_pool4
        
        upscore4 = self.upscore_pool4(fuse4)
        score_pool3 = self.score_pool3(pool3)
        fuse3 = upscore4 + score_pool3
        
        # Final upsampling to the input size
        return F.interpolate(fuse3, size=input_size, mode='bilinear', align_corners=False)
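
A quick shape check (a minimal smoke test; the input side should be divisible by 32 so the pooling stages divide evenly):

model = FCN8s(num_classes=21)
x = torch.randn(1, 3, 512, 512)
print(model(x).shape)  # torch.Size([1, 21, 512, 512])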

2.2 U-Net: The "Gold Standard" of Medical Segmentation

U-Net (2015) was originally designed for biomedical image segmentation. Thanks to its symmetric U-shaped structure and efficient skip connections, it has become one of the most widely used base architectures in segmentation.

Core characteristics

  1. Symmetric encoder-decoder: the encoder downsamples to extract semantics, while the decoder upsamples to restore spatial resolution;
  2. Concatenation-based skip connections: encoder features are concatenated with decoder features rather than added element-wise as in FCN, preserving more detail;
  3. Friendly to small datasets: it performs well even with limited annotated data.

PyTorch implementation

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()
        features = init_features
        
        # Encoder (downsampling path)
        self.enc1 = self._conv_block(in_channels, features)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = self._conv_block(features, features*2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = self._conv_block(features*2, features*4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = self._conv_block(features*4, features*8)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        # Bottleneck
        self.bottleneck = self._conv_block(features*8, features*16)
        
        # Decoder (upsampling path)
        self.upconv4 = nn.ConvTranspose2d(features*16, features*8, 2, 2)
        self.dec4 = self._conv_block(features*16, features*8)
        self.upconv3 = nn.ConvTranspose2d(features*8, features*4, 2, 2)
        self.dec3 = self._conv_block(features*8, features*4)
        self.upconv2 = nn.ConvTranspose2d(features*4, features*2, 2, 2)
        self.dec2 = self._conv_block(features*4, features*2)
        self.upconv1 = nn.ConvTranspose2d(features*2, features, 2, 2)
        self.dec1 = self._conv_block(features*2, features)
        
        # Output layer
        self.outconv = nn.Conv2d(features, out_channels, 1)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))
        
        # Bottleneck
        bottleneck = self.bottleneck(self.pool4(enc4))
        
        # Decoder + skip connections
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat([dec4, enc4], dim=1)
        dec4 = self.dec4(dec4)
        
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)
        dec3 = self.dec3(dec3)
        
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)
        dec2 = self.dec2(dec2)
        
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat([dec1, enc1], dim=1)
        dec1 = self.dec1(dec1)
        
        return self.outconv(dec1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
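
A shape check (spatial dimensions should be divisible by 16 so the four pooling stages divide evenly):

model = UNet(in_channels=3, out_channels=1)
x = torch.randn(1, 3, 256, 256)
print(model(x).shape)  # torch.Size([1, 1, 256, 256])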

2.3 DeepLab: Atrous Convolution and Multi-Scale Modeling

The DeepLab series (2016-2018) is built around atrous convolution (dilated convolution), which enlarges the receptive field without reducing resolution, together with ASPP (Atrous Spatial Pyramid Pooling) to capture multi-scale context.
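
To see why dilation enlarges the receptive field without shrinking the feature map: a 3x3 kernel with dilation rate r covers an effective extent of 3 + 2(r − 1) pixels, and padding=r preserves the spatial size. A minimal sketch:

import torch
import torch.nn as nn

# dilation=6: effective kernel extent 3 + 2*(6-1) = 13 pixels, padding=6 keeps the size
conv_r6 = nn.Conv2d(256, 256, 3, padding=6, dilation=6)
x = torch.randn(1, 256, 32, 32)
print(conv_r6(x).shape)  # torch.Size([1, 256, 32, 32]) -- resolution preserved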

Core component: ASPP

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rates=(6, 12, 18)):
        super().__init__()
        modules = []
        # 1x1 convolution branch
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ))
        # 3x3 convolutions with different dilation rates
        for rate in atrous_rates:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ))
        # Global average pooling branch
        modules.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ))
        
        self.convs = nn.ModuleList(modules)
        # Projection for feature fusion
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs)*out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        # Upsample the global-pooling result
        res[-1] = F.interpolate(res[-1], size=res[0].shape[2:], mode='bilinear', align_corners=False)
        return self.project(torch.cat(res, dim=1))
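
A shape check (assuming a ResNet-style backbone that outputs 2048 channels at 1/16 resolution; eval mode is used because the BatchNorm in the global-pooling branch sees only one value per channel when training with batch size 1):

aspp = ASPP(in_channels=2048, out_channels=256).eval()
x = torch.randn(1, 2048, 32, 32)
print(aspp(x).shape)  # torch.Size([1, 256, 32, 32])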

3. Loss Functions for Semantic Segmentation

Segmentation often faces severe class imbalance (e.g., tumor pixels make up a tiny fraction of a medical image), so besides standard cross-entropy there are several specialized losses:

Loss | When to use | Core idea
--- | --- | ---
Cross Entropy | class-balanced datasets | standard per-pixel classification loss
Dice Loss | sparse foreground / small objects | optimizes the overlap between prediction and label
Focal Loss | class imbalance + many hard examples | down-weights easy examples to focus on hard ones
Lovász Loss | directly optimizing the IoU metric | a smooth surrogate of IoU

Implementations of common losses

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        # Binary segmentation: inputs are raw single-channel logits, targets are 0/1 masks
        inputs = torch.sigmoid(inputs).view(-1)
        targets = targets.view(-1).float()
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)
        return 1 - dice

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Multi-class focal loss built on top of per-pixel cross-entropy
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        return (self.alpha * (1-pt)**self.gamma * ce_loss).mean()

# Combined loss: weighted sum of BCE and Dice (binary segmentation; expects
# single-channel logits and 0/1 masks of the same shape, so both terms agree)
class CombinedLoss(nn.Module):
    def __init__(self, weight_bce=1.0, weight_dice=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.weight_bce = weight_bce
        self.weight_dice = weight_dice

    def forward(self, inputs, targets):
        targets = targets.float()
        return self.weight_bce * self.bce(inputs, targets) + self.weight_dice * self.dice(inputs, targets)
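
Quick usage checks (random tensors stand in for real predictions and masks):

logits = torch.randn(2, 21, 64, 64)              # multi-class logits [N, C, H, W]
targets = torch.randint(0, 21, (2, 64, 64))      # class-index masks [N, H, W]
print(FocalLoss()(logits, targets).item())

bin_logits = torch.randn(2, 1, 64, 64)           # binary logits
bin_masks = torch.randint(0, 2, (2, 1, 64, 64))  # 0/1 masks, same shape as logits
print(CombinedLoss()(bin_logits, bin_masks).item())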

4. Data Preprocessing and Augmentation

The crucial point in segmentation is that images and masks must be transformed in sync; the Albumentations library (with built-in synchronized transforms) is recommended.

Segmentation-oriented augmentation strategy

import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.Resize(512, 512),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5),
    A.OneOf([A.OpticalDistortion(), A.GridDistortion()], p=0.3),
    A.OneOf([A.CLAHE(), A.RandomBrightnessContrast()], p=0.3),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

Custom dataset class

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class SegDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Masks are assumed to share filenames with their images
        self.imgs = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.imgs[idx])
        mask_path = os.path.join(self.mask_dir, self.imgs[idx])
        img = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path), dtype=np.int64)
        
        if self.transform:
            aug = self.transform(image=img, mask=mask)
            img, mask = aug["image"], aug["mask"]
        return img, mask
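
Wiring the dataset into a loader (a sketch; the directory layout below is hypothetical):

from torch.utils.data import DataLoader

# Hypothetical paths -- adapt to your dataset layout
train_ds = SegDataset("data/train/images", "data/train/masks", transform=train_transform)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4)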

5. Model Training and Evaluation

Core training loop

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    for imgs, masks in tqdm(loader):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_miou = 0.0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        total_loss += loss.item() * imgs.size(0)
        # Compute mIoU (per-batch approximation; compute_miou is defined below)
        preds = torch.argmax(outputs, dim=1)
        total_miou += compute_miou(preds, masks, num_classes=21) * imgs.size(0)
    return total_loss / len(loader.dataset), total_miou / len(loader.dataset)

Core evaluation metric: mIoU

import numpy as np

def compute_miou(preds, targets, num_classes):
    # Per-class IoU averaged over classes; classes absent from both
    # prediction and target (union == 0) are skipped
    ious = []
    for cls in range(num_classes):
        pred_cls = preds == cls
        target_cls = targets == cls
        intersection = (pred_cls & target_cls).sum().item()
        union = (pred_cls | target_cls).sum().item()
        if union == 0:
            continue
        ious.append(intersection / union)
    return np.mean(ious) if ious else 0.0
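
Putting the pieces together, a minimal training driver might look like this (a sketch; the model, loss, epoch count, and val_loader are illustrative choices, with train_loader built as in section 4):

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(in_channels=3, out_channels=21).to(device)  # e.g. 21 classes (PASCAL VOC)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_miou = 0.0
for epoch in range(50):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_miou = validate(model, val_loader, criterion, device)
    if val_miou > best_miou:
        best_miou = val_miou
        torch.save(model.state_dict(), "best_unet.pth")
    print(f"epoch {epoch}: train {train_loss:.4f} | val {val_loss:.4f} | mIoU {val_miou:.4f}")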

6. Modern Segmentation Architecture Trends

  1. Transformers: hybrid or pure Transformer architectures such as SegFormer, Swin-Unet, and TransUNet offer stronger long-range modeling;
  2. Real-time segmentation: lightweight architectures such as BiSeNet, DFANet, and Fast-SCNN balance speed and accuracy for mobile and autonomous-driving scenarios;
  3. Unified models: e.g. Mask2Former, which unifies semantic, instance, and panoptic segmentation.

Semantic segmentation is an advanced deep-learning vision task, best approached after mastering CNN fundamentals and image classification. A good entry point is U-Net on a small medical or remote-sensing dataset, then moving on to more advanced architectures such as DeepLab and SegFormer.

7. Summary

The core of semantic segmentation is per-pixel classification, and the evolution of the classic architectures revolves around restoring spatial resolution and fusing multi-scale information:

  • FCN pioneered fully convolutional prediction and skip connections;
  • U-Net became a general-purpose foundation with its symmetric U-shaped structure;
  • DeepLab introduced atrous convolution and ASPP to address multi-scale modeling.
