GAN (Generative Adversarial Networks,生成对抗网络)

1. 前言

如果说传统的卷积神经网络(CNN)是让计算机“看懂”图像,那么 GAN 就是让计算机“学会创造”图像。GAN 的核心思想源于博弈论中的纳什均衡。

想象一个场景:

  • 生成器 (Generator):一个伪造钞票的罪犯,目标是做出让警察看不出真假的假币。
  • 判别器 (Discriminator):一名警察,目标是准确分辨出哪些是真币,哪些是假币。

在这个博弈过程中,罪犯(生成器)的伪造技术越来越精湛,警察(判别器)的鉴别能力也越来越敏锐。最终,生成器制造出的假币达到了“以假乱真”的程度,连警察也无法分辨。这就是 GAN 训练的终极目标。


2. 网络概述

GAN 由两个互相对抗的网络组成:

2.1 生成器 (G)

  • 输入:一个随机噪声向量 zz(通常服从高斯分布)。
  • 输出:一张伪造的图像 G(z)G(z)
  • 目标:尽可能提高图像质量,让判别器判定其为“真”。

2.2 判别器 (D)

  • 输入:一张真实的图像 xx 或生成器生成的图像 G(z)G(z)
  • 输出:一个概率值(0到1之间),表示输入图像是“真”的概率。
  • 目标:最大化分辨真假的能力。

2.3 目标函数(Loss Function)

GAN 的训练是一个 Min-Max 极大极小博弈问题,其公式如下: minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

  • 对于 D 来说:希望 D(x)D(x) 趋近于 1,D(G(z))D(G(z)) 趋近于 0(最大化公式)。
  • 对于 G 来说:希望 D(G(z))D(G(z)) 趋近于 1,即让 1D(G(z))1-D(G(z)) 趋近于 0(最小化公式)。

3. 详细网络结构:PyTorch 实现

我们以最基础的 DCGAN (Deep Convolutional GAN) 架构为例,使用 PyTorch 实现一个简单的头像生成逻辑。

3.1 生成器 (Generator) 结构

生成器使用 转置卷积 (ConvTranspose2d) 将低维噪声矢量“放大”成高维图像。

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入是 nz 维度的噪声,进入转置卷积
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 状态尺寸: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 状态尺寸: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 状态尺寸: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh() # 输出层使用 Tanh,将像素值归一化到 [-1, 1]
            # 状态尺寸: nc x 32 x 32
        )

    def forward(self, input):
        return self.main(input)

3.2 判别器 (Discriminator) 结构

判别器本质上是一个二分类器,使用步长卷积逐渐降低空间维度。

class Discriminator(nn.Module):
    def __init__(self, nc, ndf):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入尺寸: nc x 32 x 32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 状态尺寸: ndf x 16 x 16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 状态尺寸: (ndf*2) x 8 x 8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 状态尺寸: (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid() # 最终输出 0~1 的概率
        )

    def forward(self, input):
        return self.main(input).view(-1)

4. 训练核心逻辑

GAN 的训练与普通网络不同,需要在一个循环内交替更新 D 和 G。

# 实例化
netG = Generator(nz=100, ngf=64, nc=3).to(device)
netD = Discriminator(nc=3, ndf=64).to(device)
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # --- (1) 更新判别器 D: 最大化 log(D(x)) + log(1 - D(G(z))) ---
        netD.zero_grad()
        real_img = data[0].to(device)
        batch_size = real_img.size(0)
        label = torch.full((batch_size,), 1.0, device=device) # 真标签为1

        output = netD(real_img)
        errD_real = criterion(output, label)
        errD_real.backward()

        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_img = netG(noise)
        label.fill_(0.0) # 假标签为0
        output = netD(fake_img.detach()) # 注意要 detach,不更新 G
        errD_fake = criterion(output, label)
        errD_fake.backward()
        optimizerD.step()

        # --- (2) 更新生成器 G: 最大化 log(D(G(z))) ---
        netG.zero_grad()
        label.fill_(1.0) # G 希望 D 认为假币是真的
        output = netD(fake_img)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

5. 总结与建议

GAN 的训练痛点:

  1. 模式崩溃 (Mode Collapse):生成器只学会生成一种非常像真钱的假币(例如只生成一种人脸),失去了多样性。
  2. 不收敛:D 和 G 的力量不平衡导致梯度消失或爆炸。