数据增强 (Data Augmentation):翻转、裁剪、遮挡提升泛化性

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:手写数字识别 (MNIST) 实战 · 迁移学习 (Transfer Learning)


1. 为什么需要数据增强?

问题:
  - 训练数据有限
  - 模型容易过拟合
  - 泛化性差

解决方案:
  - 数据增强:从有限数据生成更多变化
  - 让模型学到更鲁棒的特征

2. 常见增强方法

import torchvision.transforms as transforms
from PIL import Image

# 基础增强
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # 50% 概率水平翻转
    transforms.RandomVerticalFlip(p=0.5),    # 50% 概率竖直翻转
    transforms.RandomRotation(15),            # 随机旋转 ±15 度
    transforms.RandomCrop(224, padding=4),    # 随机裁剪
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# 应用增强
img = Image.open("photo.jpg")
augmented_img = transform(img)

3. 高级增强技术

3.1 Mixup

"""
Mixup:混合两张图像和标签

y_mixed = λ × y1 + (1-λ) × y2
x_mixed = λ × x1 + (1-λ) × x2
"""

import torch
import numpy as np

def mixup(x1, y1, x2, y2, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    x_mixed = lam * x1 + (1 - lam) * x2
    y_mixed = lam * y1 + (1 - lam) * y2
    return x_mixed, y_mixed

3.2 CutMix

"""
CutMix:随机裁剪并混合两张图像

比 Mixup 更有效,保留了局部结构
"""

def cutmix(x1, y1, x2, y2, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size, _, h, w = x1.size()
    
    cut_ratio = np.sqrt(1. - lam)
    cut_h = int(h * cut_ratio)
    cut_w = int(w * cut_ratio)
    
    cx = np.random.randint(0, w)
    cy = np.random.randint(0, h)
    
    bbx1 = np.clip(cx - cut_w // 2, 0, w)
    bby1 = np.clip(cy - cut_h // 2, 0, h)
    bbx2 = np.clip(cx + cut_w // 2, 0, w)
    bby2 = np.clip(cy + cut_h // 2, 0, h)
    
    x_mixed = x1.clone()
    x_mixed[:, :, bby1:bby2, bbx1:bbx2] = x2[:, :, bby1:bby2, bbx1:bbx2]
    
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (h * w))
    y_mixed = lam * y1 + (1 - lam) * y2
    
    return x_mixed, y_mixed

3.3 RandAugment

"""
RandAugment:随机选择增强操作

自动搜索最优的增强策略
"""

from torchvision.transforms import AutoAugment, AutoAugmentPolicy

transform = transforms.Compose([
    AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
    transforms.ToTensor(),
])

4. 小结

数据增强三层次:

基础:翻转、旋转、裁剪、颜色抖动
进阶:Mixup、CutMix
自动:AutoAugment、RandAugment

效果:
- 基础增强:准确率 +1-2%
- Mixup/CutMix:准确率 +2-3%
- AutoAugment:准确率 +3-5%

💡 记住:数据增强是提升模型性能最简单有效的方法。在数据有限的情况下,增强往往比增加模型深度更有效。


🔗 扩展阅读