迁移学习 (Transfer Learning):利用预训练模型快速构建高性能模型

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:数据增强 (Data Augmentation) · 目标检测理论


引言

你是否被这些现实问题困扰过?

  • 手头只有几百张私有标注图片,从零开始训练CNN,准确率连60%都不到;
  • 好不容易攒够了数据,GPU显存和时间却根本撑不住大规模模型的训练。

迁移学习(Transfer Learning) 就是为了解决这些痛点而生的现代深度学习利器。它的核心思想很简单:复用已经在大规模通用数据集(比如ImageNet的120万张图片)上训练好的视觉特征,再针对你的特定任务做少量的调整,就能快速获得一个性能优异的模型。


1. 核心概念与工作原理

1.1 为什么迁移学习这么有效?

CNN的不同层学习到的特征具有天然的通用性分层

网络层级学到的特征迁移价值
浅层(低层)边缘、角点、基础纹理通用度极高,几乎无需修改
中层形状组合、简单对象部件通用度较高,可微调少量
深层(顶层+FC)完整对象、语义分类(ImageNet类别)通用度低,必须替换或大幅微调

正是因为浅层和中层特征对几乎所有视觉任务都通用,我们只需调整那些与特定类别相关的深层部分,就能用很少的数据完成适应。

1.2 核心迁移策略(按保守程度排序)

根据你的数据量和计算资源,可以直接套用下面这张决策速查表:

条件组合推荐策略资源/时间消耗适用场景
小数据集(<1000)+ 低算力特征提取(仅训练新分类头)极低个人实验、快速原型验证
小数据集(<1000)+ 相似通用任务特征提取 → 微调最后1-2个CNN模块医学/自然小分类任务
中等数据集(1k-10k)+ 中算力分层微调(冻结前N层,微调后面层+新头)通用工业分类、数据集相似但规模够大
大数据集(>10k)+ 高算力全量微调(用小学习率训练所有层)高精度需求、跨领域适配后再优化

2. PyTorch实现迁移学习

2.1 准备工作:修正过时API + 基础配置

⚠️ 重要提示:torchvision 0.13+ 版本已经弃用了 pretrained=True,现在推荐使用更规范的 weights=预训练权重类

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision.models import (
    ResNet50_Weights,
    VGG16_Weights,
    MobileNet_V3_Small_Weights
)

2.2 完整的特征提取 / 分层微调代码

下面以最常用的 ResNet50 + 私有图像分类 为例,编写了一个同时兼容特征提取和分层微调的函数:

def build_transfer_model(
    model_name: str = "resnet50",
    num_classes: int = 10,
    strategy: str = "feature_extract",  # 可选: feature_extract, fine_tune, full_tune
    freeze_layers: int = 4  # 仅fine_tune有效;ResNet50共有8个children,这里冻结前4个
):
    """
    构建迁移学习模型
    """
    # 1. 加载带预训练权重的模型
    if model_name == "resnet50":
        weights = ResNet50_Weights.IMAGENET1K_V1
        model = models.resnet50(weights=weights)
        in_features = model.fc.in_features
        head = "fc"
    elif model_name == "vgg16":
        weights = VGG16_Weights.IMAGENET1K_V1
        model = models.vgg16(weights=weights)
        in_features = model.classifier[6].in_features
        head = "classifier"
    elif model_name == "mobilenet_v3_small":
        weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        model = models.mobilenet_v3_small(weights=weights)
        in_features = model.classifier[3].in_features
        head = "classifier"
    else:
        raise ValueError(f"不支持的模型: {model_name}")

    # 2. 根据策略冻结参数
    if strategy == "feature_extract":
        # 冻结所有参数,只训练新分类头
        for param in model.parameters():
            param.requires_grad = False
    elif strategy == "fine_tune":
        # 只冻结前 freeze_layers 个模块
        children = list(model.children()) if hasattr(model, "children") else []
        for child in children[:freeze_layers]:
            for param in child.parameters():
                param.requires_grad = False
    # full_tune 不冻结任何参数

    # 3. 替换任务特定的分类头
    if model_name == "resnet50":
        model.fc = nn.Linear(in_features, num_classes)
    elif model_name == "vgg16":
        model.classifier[6] = nn.Linear(in_features, num_classes)
    elif model_name == "mobilenet_v3_small":
        model.classifier[3] = nn.Linear(in_features, num_classes)

    # 同时返回预训练模型要求的预处理方法,避免手动写错归一化参数
    return model, weights.transforms()

2.3 数据加载(注意!预处理必须和预训练一致)

刚刚 build_transfer_model 返回的预训练预处理可以直接复用,这样就能保证我们的数据预处理和当初训练ImageNet时完全一致:

def get_data_loaders(
    dataset_path: str = "./data",
    preprocess: transforms.Compose = None,
    batch_size: int = 32,
    train_augment: bool = True
):
    """
    获取训练/验证数据加载器
    """
    # 训练集:预训练预处理 + 可选的数据增强
    train_transforms = [preprocess]
    if train_augment:
        # 简单而有效的数据增强
        train_transforms.insert(0, transforms.RandomHorizontalFlip(p=0.5))
        train_transforms.insert(0, transforms.Resize(256))
        train_transforms.insert(1, transforms.RandomCrop(224))
    train_transform = transforms.Compose(train_transforms)

    # 验证/测试集:只使用预训练预处理,不做增强
    val_transform = preprocess

    # 加载ImageFolder格式的数据集(最常用:根目录下直接是类别子文件夹)
    train_dataset = datasets.ImageFolder(f"{dataset_path}/train", transform=train_transform)
    val_dataset = datasets.ImageFolder(f"{dataset_path}/val", transform=val_transform)

    # 构建DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, train_dataset.classes

2.4 训练循环(极简版,包含早停逻辑)

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int = 10,
    lr: float = 0.001,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
    """
    训练迁移学习模型
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    # 只优化那些需要训练的参数(即 requires_grad=True 的部分)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    # 学习率调度器:验证准确率不再提升时自动降低学习率
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

    best_val_acc = 0.0
    best_model_state = model.state_dict()

    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            train_total += labels.size(0)
            train_correct += preds.eq(labels).sum().item()

        train_loss /= train_total
        train_acc = train_correct / train_total

        # 验证阶段
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, preds = outputs.max(1)
                val_total += labels.size(0)
                val_correct += preds.eq(labels).sum().item()

        val_loss /= val_total
        val_acc = val_correct / val_total
        scheduler.step(val_acc)  # 根据验证准确率调整学习率

        # 保存最佳模型
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = model.state_dict().copy()

        # 打印训练日志
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    print(f"\n训练完成!最佳验证准确率: {best_val_acc:.4f}")
    model.load_state_dict(best_model_state)
    return model

3. 快速上手示例

把上面的函数拼起来,一个完整的迁移学习流程只需十几行代码:

if __name__ == "__main__":
    # 1. 构建模型(以特征提取为例)
    model, preprocess = build_transfer_model(
        model_name="resnet50",
        num_classes=5,  # 假设你的私有数据集有5个类别
        strategy="feature_extract"
    )

    # 2. 加载数据(确保数据集是ImageFolder格式:./data/train/类1, ./data/val/类1...)
    train_loader, val_loader, classes = get_data_loaders(
        dataset_path="./data",
        preprocess=preprocess,
        batch_size=32
    )
    print(f"类别列表: {classes}")

    # 3. 训练模型
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=10,
        lr=0.001
    )

    # 4. 保存最佳模型
    torch.save(trained_model.state_dict(), "best_transfer_model.pth")

4. 最佳实践

为了让你的迁移学习之旅更顺利,这里总结了几条实战经验:

  1. 预处理必须一致:直接使用预训练权重自带的 transforms(),避免手动填写归一化均值、标准差时出错。
  2. 从保守策略开始:先用特征提取跑通整个流程,确认数据和代码没问题之后,再尝试分层微调或全量微调。
  3. 分层学习率(可选但强烈推荐):微调时,可以给不同层设置不同的学习率,例如新分类头用 1e-3,最后 1-2 个 CNN 模块用 1e-4,前面的层用 1e-5 或更小——这样能让底层通用特征保持稳定,高层特征逐渐适应新任务。
  4. 早停机制:避免不必要的过度训练,始终保存验证集上表现最好的模型。
  5. 数据增强不可少:即使只有几百张图片,也应当加入简单的数据增强(水平翻转、随机裁剪),这能有效抑制过拟合。
迁移学习是现代深度学习的基石之一。建议从最简单的特征提取开始练习,逐步深入微调技术。记住,在现在的项目实践中,从零开始训练一个完整的 CNN 模型已经很少见——除非你的任务和所有公开预训练数据集都完全无关,并且你拥有超大规模的数据。

扩展阅读