GAN详解:生成对抗网络原理与PyTorch实现

如果说传统CNN是让计算机"看懂"图像,那么GAN就是让计算机"学会创造"——它由Ian Goodfellow等人在2014年提出,用博弈论纳什均衡开创了无监督生成的新篇章。


1. GAN核心:造假者与鉴定师的博弈

1.1 生动比喻

  • 生成器 G:技艺精湛的造假者,输入随机噪声,输出「以假乱真」的样本
  • 判别器 D:经验丰富的鉴定师,输入样本,输出「为真」的概率

两者零和博弈:造假者不断精进,鉴定师同步升级,最终纳什均衡——生成样本与真实数据分布几乎一致,判别器输出恒为0.5。

1.2 核心优势

无需标注数据即可学习;可生成图像/音频/文本;图像生成质量远超传统生成模型(如VAE)。


2. 架构与工作原理

2.1 极小极大博弈

GAN的训练可以看作一个相互对抗的优化过程:

  • 判别器 D 希望成为最挑剔的鉴定师:对真实图像输出接近1的高分,对生成的假图像输出接近0的低分。它的目标是最大化自己正确分类的能力。
  • 生成器 G 的目标则完全相反:它想方设法让判别器给自己生成的假图像打出高分(接近1),即最小化自己被识破的概率。

训练时,两者交替更新。每轮先固定生成器,训练判别器提高鉴别能力;再固定判别器,训练生成器提高造假水准。随着迭代,两者能力螺旋上升,最终理想情况下,生成器产出的样本真假难辨,判别器无从判断,只能对所有样本输出0.5,达到博弈均衡。

2.2 基础组件

  • 输入噪声:生成器的输入是随机向量,通常从标准正态分布(均值为0、方差为1)中采样。
  • 生成器结构:利用转置卷积(Transposed Convolution)将低维噪声逐步上采样,最终映射为高维数据(如图像)。
  • 判别器结构:利用标准卷积层逐步下采样,提取特征,最后通过 Sigmoid 输出一个0到1之间的概率值。
  • 输出说明:判别器输出真假概率;生成器通常使用 Tanh 激活函数,将生成图像的像素值限制在[-1, 1]之间,以匹配数据预处理时的归一化范围。

3. DCGAN PyTorch快速实现

DCGAN是GAN的卷积化标准实现,通过规范的网络设计和训练技巧大幅提升了稳定性。

3.1 环境准备

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

3.2 生成器

class DCGANGenerator(nn.Module):
    """输入100维噪声 -> 输出3×64×64归一化图像"""
    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        self.layers = nn.Sequential(
            # nz -> ngf*8×4×4
            nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(True),
            # ngf*8×4×4 -> ngf*4×8×8
            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # ngf*4×8×8 -> ngf×32×32
            nn.ConvTranspose2d(ngf*4, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # ngf×32×32 -> nc×64×64
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.layers(z)

3.3 判别器

class DCGANDiscriminator(nn.Module):
    """输入3×64×64图像 -> 输出[0,1]置信度"""
    def __init__(self, nc=3, ndf=64):
        super().__init__()
        self.layers = nn.Sequential(
            # nc×64×64 -> ndf×32×32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf×32×32 -> ndf*2×16×16
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf*2×16×16 -> ndf*4×8×8
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf*4×8×8 -> 1×1×1
            nn.Conv2d(ndf*4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x).view(-1)

3.4 数据加载与训练

数据预处理

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1,1]
])

# 替换为你的本地数据集路径
dataset = datasets.CIFAR10(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

训练循环

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 初始化模型、优化器、损失
G = DCGANGenerator().to(device)
D = DCGANDiscriminator().to(device)
criterion = nn.BCELoss()
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 固定噪声用于可视化
fixed_noise = torch.randn(16, 100, 1, 1, device=device)

num_epochs = 5
G_losses, D_losses = [], []

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)
        real_label = torch.ones(batch_size, device=device)
        fake_label = torch.zeros(batch_size, device=device)

        ###########################
        # 训练判别器
        ###########################
        D.zero_grad()
        # 真实样本损失
        output_real = D(real_imgs)
        errD_real = criterion(output_real, real_label)
        errD_real.backward()
        # 生成样本损失
        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_imgs = G(noise)
        output_fake = D(fake_imgs.detach())  # 冻结生成器梯度
        errD_fake = criterion(output_fake, fake_label)
        errD_fake.backward()
        # 更新判别器
        errD = errD_real + errD_fake
        optimizerD.step()

        ###########################
        # 训练生成器
        ###########################
        G.zero_grad()
        # 反向欺骗判别器(希望假样本被判别为真)
        output_G = D(fake_imgs)
        errG = criterion(output_G, real_label)
        errG.backward()
        # 更新生成器
        optimizerG.step()

        # 记录损失
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # 每50步输出训练信息
        if i % 50 == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}][Batch {i}/{len(dataloader)}] "
                  f"Loss_D: {errD:.4f} Loss_G: {errG:.4f} "
                  f"D(G(z)): {output_G.mean().item():.4f}")

    # 每轮可视化生成结果
    with torch.no_grad():
        fake_fixed = G(fixed_noise).detach().cpu()
    grid = utils.make_grid(fake_fixed, nrow=4, normalize=True)
    plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1,2,0))
    plt.axis('off')
    plt.show()

训练监控提示D(G(z)) 表示判别器对生成样本给出的平均评分,理想情况下该值会逐渐趋近于0.5,说明生成器已经能够成功迷惑判别器。


4. 常见挑战与改进方向

4.1 主要挑战

  1. 模式崩坏:生成器只输出有限种类样本,缺乏多样性。
  2. 训练不稳定:损失震荡,判别器或生成器一方过强导致梯度消失/爆炸。
  3. 评估困难:缺乏绝对客观的生成质量指标(如FID、IS等仅作相对参考)。

4.2 经典改进方案

改进方案解决问题核心思路
WGAN模式崩坏/不稳定用Wasserstein距离代替传统GAN的目标度量
WGAN-GPWGAN梯度裁剪生硬用梯度惩罚代替直接裁剪网络权重
CycleGAN无配对图像翻译加入循环一致性损失,实现无需配对的风格转换
StyleGAN精细控制生成风格引入风格映射网络、AdaIN归一化,实现属性解耦控制

5. 实践建议

  1. 数据预处理:严格将图像归一化到[-1, 1],与生成器最后一层的 Tanh 激活相匹配。
  2. 优化器选择:固定使用 Adam,学习率设为 0.0002,beta1=0.5,这些参数在实践中表现稳健。
  3. 训练策略:交替训练判别器和生成器,尽量避免某一方碾压另一方。可尝试对判别器做更多次更新,之后再更新一次生成器。
  4. 监控手段:除了观察损失曲线,务必定期可视化生成样本——损失有时会撒谎,人眼检查图像质量才是最直接的。

总结

GAN用简单的博弈思想实现了惊人的生成效果,虽有训练挑战,但仍是AI创作领域的核心工具之一。通过本文的DCGAN实现,你可以快速上手GAN,后续可根据需求选择CycleGAN、StyleGAN等变体。


🔗 扩展阅读