SRGAN详解:超分辨率生成对抗网络原理与PyTorch实现

想象你翻出一张10年前320×240像素的毕业照,手指一放大,人脸糊成马赛克,黑板上的粉笔字完全认不出来。双三次插值之类的方法只能给你一种“模模糊糊的平滑感”,但2017年Ledig等人提出的SRGAN(Super‑Resolution GAN)却能给你“回忆的清晰感”。它第一次把生成对抗网络引入超分辨率任务,让图像放大从“像素填充”跨越到了“细节重建”。


1. SRGAN概述

1.1 传统方法的痛点

在SRGAN出现之前,主流超分方法(比如SRCNN)大都靠最小化均方误差(MSE)来训练。这样虽然能在PSNR这类数值指标上拿到高分,但图像看起来总像“被磨了皮”——关键的头发丝、皮肤纹理、建筑边缘等高频细节都丢失了,视觉上很不自然。

1.2 两个核心创新

1. **对抗性损失(Adversarial Loss)**:让生成器造假,让判别器打假。两者反复博弈,逼着生成器学会画出“能骗过人眼”的纹理细节。 2. **感知损失(Perceptual Loss)**:不再傻傻地逐像素比较,而是用预训练好的VGG网络提取**深层语义特征**,让生成的图和真实高分辨率图在“看起来像不像”这一层面更接近。

1.3 主要优势

  • 视觉真实感远超传统插值或纯CNN方法
  • 4倍甚至更高倍率放大时依然能重建出可信的细节
  • 架构可以迁移到医学影像、卫星遥感、视频增强等领域

2. 核心架构:三组件协同

SRGAN不是一个孤零零的网络,而是由生成器、判别器、VGG感知损失网络三位一体构成的。

2.1 生成器:低清→高清的魔术棒

生成器采用16个残差块(SRResNet骨架)+ PixelShuffle上采样。残差块负责深层特征提取,并有效防止梯度消失;PixelShuffle则是一种优雅的亚像素卷积上采样方式,专门用来避免棋盘伪影。

关键组件代码(精简)

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """SRGAN残差块:跳跃连接防止梯度消失"""
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),  # PReLU比ReLU更适合超分任务(减少死神经元)
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)   # 局部跳跃连接

class UpsampleBlock(nn.Module):
    """PixelShuffle上采样:避免棋盘伪影"""
    def __init__(self, in_channels, up_scale=2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * up_scale**2, 3, padding=1),
            nn.PixelShuffle(up_scale),
            nn.PReLU(),
        )

    def forward(self, x):
        return self.block(x)

完整生成器(代码折叠)

点击查看完整Generator
class Generator(nn.Module):
    """SRGAN生成器:默认4倍放大"""
    def __init__(self, scale_factor=4, num_res_blocks=16):
        super().__init__()
        self.scale_factor = scale_factor
        
        # 1. 初始低级特征提取
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4),
            nn.PReLU()
        )
        
        # 2. 16层残差块,深度提取高级特征
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_res_blocks)]
        )
        
        # 3. 中间卷积+全局跳跃连接(保留低频结构信息)
        self.mid_conv = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64)
        )
        
        # 4. PixelShuffle上采样(每次2倍,4倍需2次)
        self.upsample = nn.Sequential(
            *[UpsampleBlock(64, 2) for _ in range(int(scale_factor/2))]
        )
        
        # 5. 输出RGB图像(Tanh归一化到[-1,1])
        self.last_conv = nn.Sequential(
            nn.Conv2d(64, 3, 9, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        out1 = self.first_conv(x)
        out = self.res_blocks(out1)
        out = self.mid_conv(out)
        out = out1 + out          # 全局跳跃连接:把浅层信息直接送到深层
        out = self.upsample(out)
        out = self.last_conv(out)
        return out

2.2 判别器:真假图像的判官

判别器本质上是一个8层卷积网络,交替使用步长为1和步长为2的卷积来逐步提取特征,最后接全局平均池化和一个分类头,输出0~1的置信度(0=生成图/假,1=真实高分辨率图/真)。

点击查看完整Discriminator
class Discriminator(nn.Module):
    """SRGAN判别器:二分类,判断图像是真实高清图还是生成图"""
    def __init__(self, input_shape=(3, 96, 96)):
        super().__init__()
        self.input_shape = input_shape
        
        def conv_block(in_f, out_f, stride=1, norm=True):
            layers = [nn.Conv2d(in_f, out_f, 3, stride, padding=1)]
            if norm:
                layers.append(nn.BatchNorm2d(out_f))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # 特征提取主干
        self.backbone = nn.Sequential(
            *conv_block(3, 64, stride=1, norm=False),
            *conv_block(64, 64, stride=2),
            *conv_block(64, 128, stride=1),
            *conv_block(128, 128, stride=2),
            *conv_block(128, 256, stride=1),
            *conv_block(256, 256, stride=2),
            *conv_block(256, 512, stride=1),
            *conv_block(512, 512, stride=2),
        )

        # 全局平均池化 + 二分类头
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        feats = self.backbone(img)
        feats = self.pool(feats).flatten(1)
        return self.classifier(feats)

3. 灵魂:损失函数设计

SRGAN的损失由两部分加权组成:内容损失(像素损失+感知损失)对抗损失。其中,感知损失是让图像“看起来真实”的关键所在。

3.1 内容损失:像素匹配 + 感知匹配

先加载一个冻结参数的VGG19网络,用它来提取图像的高层语义特征。内容损失 = 很小权重的像素MSE + 大权重的感知特征MSE,这样网络既能保证整体结构不会跑偏,又能专心画出逼真的高频纹理。

import torchvision.models as models

class VGG19Extractor(nn.Module):
    """冻结预训练VGG19,提取特征用于感知损失计算"""
    def __init__(self, layer_idx=35):  # 取VGG19的第35层(ReLU激活后)
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.extractor = nn.Sequential(*list(vgg.children())[:layer_idx+1])
        
        for p in self.extractor.parameters():
            p.requires_grad = False
        
        # ImageNet的归一化参数
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, x):
        # 生成器输出在[-1,1]范围,先映射到[0,1],再用ImageNet均值标准差归一化
        x = (x + 1) / 2
        x = (x - self.mean) / self.std
        return self.extractor(x)

def content_loss(sr, hr, vgg, mse_w=0.01):
    """内容损失 = 小权重像素MSE + 大权重感知损失"""
    pixel_loss = F.mse_loss(sr, hr)
    sr_feats = vgg(sr)
    hr_feats = vgg(hr)
    percept_loss = F.mse_loss(sr_feats, hr_feats)
    return mse_w * pixel_loss + percept_loss

3.2 对抗损失:让判别器“难辨真假”

原始GAN论文用的是交叉熵,这里换成LSGAN(最小二乘GAN)损失,可以缓解梯度消失,训练更稳定。

def adv_loss_g(generator_out):
    """生成器损失:希望判别器认为生成的图是真的(标签为1)"""
    return F.mse_loss(generator_out, torch.ones_like(generator_out))

def adv_loss_d(real_out, fake_out):
    """判别器损失:真图标1,假图标0"""
    real_l = F.mse_loss(real_out, torch.ones_like(real_out))
    fake_l = F.mse_loss(fake_out, torch.zeros_like(fake_out))
    return (real_l + fake_l) / 2

4. 训练策略:两阶段更稳

如果一开始就让生成器直接和判别器对抗,生成器输出的图太假,判别器瞬间学会“打假”,生成器得到的梯度会消失,训练很快就停滞了。通常的做法是**先预训练生成器,再引入对抗训练**。

阶段1:预训练生成器(只用MSE像素损失)

这个阶段实际上是在训练一个SRResNet网络,目标就是让放大后的图像尽量接近真实高分辨率图(按像素均方误差)。收敛快,训练稳。

def pretrain_gen(generator, dataloader, device, epochs=50, lr=1e-4):
    opt = torch.optim.Adam(generator.parameters(), lr=lr)
    mse = nn.MSELoss()
    generator.train()
    
    for epoch in range(epochs):
        for batch_idx, (hr, lr) in enumerate(dataloader):
            hr, lr = hr.to(device), lr.to(device)
            opt.zero_grad()
            sr = generator(lr)
            loss = mse(sr, hr)
            loss.backward()
            opt.step()
            if batch_idx % 100 == 0:
                print(f"Pretrain E{epoch} B{batch_idx} | MSE Loss: {loss:.4f}")
    torch.save(generator.state_dict(), "srresnet_pretrain.pth")

阶段2:对抗训练(加载预训练权重)

核心循环逻辑通常是:交替训练判别器和生成器,比如先更新一次判别器,再更新一次生成器。这样可以避免一方压倒另一方,维持动态平衡。


5. 快速上手:用SRGAN修复老照片

把模糊的老照片放大并修复细节,其实只需要加载一个预训练好的生成器,再写几行预处理代码。

from PIL import Image
import torchvision.transforms as transforms

def enhance_old_photo(img_path, generator_path, device, scale=4):
    # 加载预训练生成器
    gen = Generator(scale_factor=scale).to(device)
    gen.load_state_dict(torch.load(generator_path, map_location=device))
    gen.eval()
    
    # 预处理:转tensor并归一化到[-1,1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
    ])
    
    img = Image.open(img_path).convert('RGB')
    lr = transform(img).unsqueeze(0).to(device)
    
    # 推理
    with torch.no_grad():
        sr = gen(lr)
    
    # 后处理:反归一化回[0,1],再转PIL图像
    sr = (sr.squeeze(0).cpu() + 1) / 2
    sr = torch.clamp(sr, 0, 1)
    return transforms.ToPILImage()(sr)

6. 发展趋势与挑战

主要变体

  • ESRGAN:把残差块换成Residual‑in‑Residual Dense Block(RRDB),去掉了BN层,同时引入相对论GAN(Relativistic GAN)来强化细节。
  • Real‑ESRGAN:用纯合成数据训练,大幅提升在真实低质量图片上的泛化能力,现在已经是很多图像增强工具背后的引擎。

现存挑战

  • 推理速度偏慢,移动端部署通常需要结合量化和剪枝等加速手段。
  • 偶尔会生成“伪真实细节”——比如把原本模糊的皮肤斑点错误地画成雀斑,这在一些对精度要求极高的场景中仍是难题。

总结

SRGAN是超分辨率领域从“追求数值指标”到“追求视觉真实”的里程碑。它用残差网络、VGG感知损失和对抗训练这三板斧,实现了高质量的图像放大。后续的ESRGAN、Real‑ESRGAN等变体在此基础上不断进化,如今已经在老照片修复、视频增强、游戏纹理放大等真实场景中发挥了巨大作用。


相关教程

🔗 扩展阅读