SRGAN (Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network)

一、 前言:从“放大”到“还原”

在图像识别与处理的进阶道路上,SRGAN (Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network) 是一座里程碑。它首次将 生成对抗网络 (GAN) 引入超分辨率领域,解决了传统方法图像模糊、细节缺失的问题。
传统的图像放大技术(如双三次插值)只是在像素间进行数学填充,结果往往是边缘模糊、质感全无。而 SRGAN 的出现,让机器学会了“脑补”细节。

它不仅能把一张低分辨率(LR)的小图放大 4 倍变成高分辨率(HR)的大图,还能在放大过程中还原出复杂的纹理(如头发丝、皮肤毛孔、建筑物的缝隙)。它不再是简单的图像处理,而是一种基于深度学习的像素重构


二、 概述:SRGAN 的核心突破

在 SRGAN 之前,主流模型(如 SRCNN)主要通过最小化 均方误差 (MSE) 来训练。虽然这能获得较高的峰值信噪比(PSNR),但结果往往过于平滑,缺乏视觉上的真实感。

SRGAN 提出了两个核心创新:

  1. 对抗性损失 (Adversarial Loss):引入 GAN 架构,让生成器与判别器博弈,逼迫生成器产生更加真实的纹理。
  2. 感知损失 (Perceptual Loss):不再对比像素层面的差异,而是对比图像在经过预训练网络(如 VGG)提取后的“深层特征”是否一致。

三、 深度讲解:双网博弈结构

SRGAN 由两个相互竞争的子网络组成,其逻辑非常精妙:
核心组件:残差块 (Residual Block)
生成器和判别器都大量使用了残差块,它能让网络变得很深而不退化。

import torch
import torch.nn as nn
from torchvision.models import vgg19

# 基础残差块: Conv + BN + PReLU + Conv + BN
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU() # SRGAN 推荐使用 PReLU
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        # 核心:将输入直接加到输出上 (跳跃连接)
        return x + residual

1. 生成器网络 (Generator)

生成器采用深度残差网络(ResNet)作为基础。

  • 输入:一张低分辨率图像。
  • 过程:通过一系列具有跳跃连接的残差块提取特征,最后利用 子像素卷积层 (PixelShuffle) 进行上采样。
  • 目标:创造出一张欺骗过判别器眼睛的、细节丰富的超分图像。
class Generator(nn.Module):
    def __init__(self, scale_factor=4):
        super(Generator, self).__init__()
        
        # 1. 初始卷积层 (提取低级特征)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        
        # 2. 残差层 (共16个残差块,深层特征提取)
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(16)]
        )
        
        # 3. 中间卷积层
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        
        # 4. 上采样层 (关键:PixelShuffle 像素复用,放大图像)
        # 如果放大 4 倍,需要两个 2 倍的上采样块
        if scale_factor == 4:
            self.upsample = nn.Sequential(
                # 上采样 1: 64 -> 256 -> PixelShuffle -> 64 (尺寸增加2倍)
                nn.Conv2d(64, 256, kernel_size=3, padding=1),
                nn.PixelShuffle(2), 
                nn.PReLU(),
                # 上采样 2: 64 -> 256 -> PixelShuffle -> 64 (尺寸再次增加2倍)
                nn.Conv2d(64, 256, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU()
            )
        
        # 5. 输出层 (3通道 RGB)
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=9, padding=4),
            nn.Tanh() # 将输出像素值归一化到 [-1, 1]
        )

    def forward(self, x):
        # x shape: (Batch, 3, LR_H, LR_W)
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out = self.conv2(out)
        
        # 核心:将第一层特征加到残差层的输出上
        out = out1 + out 
        
        out = self.upsample(out)
        out = self.conv3(out)
        # 最终输出 shape: (Batch, 3, HR_H*4, HR_W*4)
        return out

2. 判别器网络 (Discriminator)

判别器是一个典型的二分类 CNN。

  • 输入:生成的超分图像或真实的原始高清图像。
  • 目标:尽可能准确地分辨出哪张是人工生成的,哪张是相机拍摄的。
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # 定义基础卷积块: Conv + BN + LeakyReLU
        def discriminator_block(in_filters, out_filters, stride, normalize):
            layers = [nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=stride, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # 构建全卷积网络
        self.model = nn.Sequential(
            *discriminator_block(3, 64, stride=1, normalize=False), # 输入层
            *discriminator_block(64, 64, stride=2, normalize=True),
            *discriminator_block(64, 128, stride=1, normalize=True),
            *discriminator_block(128, 128, stride=2, normalize=True),
            *discriminator_block(128, 256, stride=1, normalize=True),
            *discriminator_block(256, 256, stride=2, normalize=True),
            *discriminator_block(256, 512, stride=1, normalize=True),
            *discriminator_block(512, 512, stride=2, normalize=True),
        )

        # 最终分类头
        self.classifier = nn.Sequential(
            nn.Linear(512 * 14 * 14, 1024), # 假设输入 HR 图像为 224x224,经过多次 stride=2 后变为 14x14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1), # 输出真假概率
            nn.Sigmoid() # 归一化到 [0, 1]
        )

    def forward(self, img):
        # img shape: (Batch, 3, HR_H, HR_W)
        feature_maps = self.model(img)
        # 展平
        feature_maps = feature_maps.view(feature_maps.size(0), -1)
        validity = self.classifier(feature_maps)
        # 最终输出 shape: (Batch, 1) -> 接近1为真,接近0为假
        return validity

3. 感知损失函数 (The Secret Sauce)

这是 SRGAN 看起来比其他模型更清晰的关键。它计算两个图像在 VGG-19 网络中间层的特征映射差异。模型不再死磕每一个像素点是否完全一样,而是要求“感觉上”两张图呈现的内容特征是一致的。

class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super(VGGFeatureExtractor, self).__init__()
        # 加载预训练的 VGG19
        vgg19_model = vgg19(pretrained=True)
        
        # 提取 VGG19 的前 18 层 (通常使用第 3 个卷积块的第 4 个卷积层特征,即 3_4 layer)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:18])
        
        # 冻结参数,不参与训练
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, img):
        # img shape: (Batch, 3, HR_H, HR_W)
        # 注意:VGG 提取特征前需要将图像标准化
        return self.vgg19_54(img)

四、 应用场景:超分辨率在干什么?

SRGAN 及其变体(如 ESRGAN)在工业界有着极高的应用价值:

  • 老照片/影视修复:将几十年前的低清胶片或监控录像提升至 4K 画质。
  • 医学影像增强:提升超声或 MRI 图像的清晰度,帮助医生发现更细微的病灶。
  • 卫星遥感:增强卫星拍摄的地表图像,使其能够看清地面车辆或建筑细节。
  • 移动端显示:在节省带宽的情况下,传输低分辨率视频,并在用户手机端实时进行超分还原。

五、 总结与局限性

SRGAN 的地位: 它开创了追求“视觉感知力(Perceptual Quality)”而非单纯“数学准确度(PSNR)”的先河。

注意点: 虽然 SRGAN 生成的图片非常“好美”,但由于 GAN 的特性,它有时会产生一些虚假的伪影(即机器脑补出了不存在的纹理)。在对严谨性要求极高的场景(如法庭取证或自动驾驶避障)中,需要谨慎使用。