生成式人工智能实战 | 生成对抗网络
- 0. 前言
- 1. 生成对抗网络
- 2. 模型构建
- 2.1 生成器
- 2.2 判别器
- 3. 模型训练
- 3.1 数据加载
- 3.2 训练流程
0. 前言
生成对抗网络 (Generative Adversarial Networks
, GAN
) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用广泛,且 GAN
的相关研究正在迅速发展,以生成与真实图像难以区分的逼真图像。在本节中,我们将学习 GAN
网络的原理并使用 PyTorch
实现 GAN
。
1. 生成对抗网络
生成对抗网络 (Generative Adversarial Networks
, GAN
) 包含两个网络:生成网络( Generator
,也称生成器)和判别网络( discriminator
,也称判别器)。在 GAN
网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像:
- 生成器 G ( z ; θ G ) G(z;θ_G) G(z;θG) 接受噪声 z ∼ p z z∼p_z z∼pz,学习映射到数据空间,以“欺骗”判别器
- 判别器 D ( x ; θ D ) D(x;θ_D) D(x;θD) 输出样本 x x x 属于真实数据的概率,旨在区分真实与生成数据
两者通过以下最小–最大化 (minimax
) 目标函数进行博弈:
m i n G m a x D V ( D , G ) = E x ∼ p d a t a [ l o g D ( x ) ] + E z ∼ p z [ l o g ( 1 − D ( G ( z ) ) ) ] \underset G {min} \underset D {max} V(D,G)=\mathbb E_{x∼p_{data}}[logD(x)]+\mathbb E_{z∼p_z}[log(1−D(G(z)))] GminDmaxV(D,G)=Ex∼pdata[logD(x)]+Ez∼pz[log(1−D(G(z)))]
生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN
中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:
在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN
的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:
- 生成网络的训练过程:冻结判别网络权重,生成网络以噪声
z
作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像 - 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像
重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。
2. 模型构建
2.1 生成器
生成器由若干全连接层与 LeakyReLU
激活构成,最后用 Tanh
将输出映射至 [−1,1]
范围内:
# 定义生成器 G
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=10):super().__init__()self.net = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 28*28),nn.Tanh() # 输出像素映射到 [-1,1])def forward(self, z):return self.net(z).view(-1,1,28,28)
2.2 判别器
判别器使用全连接层与 LeakyReLU
激活,末端使用 Sigmoid
激活函数,输出一个标量真值估计:
# 定义判别器 D,对输入图片输出真伪概率
class Discriminator(nn.Module):def __init__(self, img_dim=28*28):super().__init__()self.model = nn.Sequential(nn.Flatten(),nn.Linear(img_dim, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):return self.model(x)
3. 模型训练
接下来,使用 MNIST
数据集训练 GAN
模型。
3.1 数据加载
将 MNIST
像素值归一化到生成器 Tanh
输出所需的 [−1,1]
区间:
# 加载并归一化 MNIST 数据集
from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)) # 映射到 [-1,1]
])
train_ds = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
3.2 训练流程
首先,初始化模型、优化器与损失函数:
import torch
import torch.optim as optimdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
z_dim = 20
G = Generator(z_dim).to(device)
D = Discriminator().to(device)opt_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
loss_fn = nn.BCELoss()
训练模型 50
个 epoch
:
epochs = 50for epoch in range(epochs):for real, _ in train_loader:real = real.to(device)batch_size = real.size(0)real_labels = torch.ones(batch_size, 1, device=device)fake_labels = torch.zeros(batch_size, 1, device=device)# 训练判别器# 在真实样本上进行训练D_real = D(real)loss_D_real = loss_fn(D_real, real_labels)opt_D.zero_grad()loss_D_real.backward()opt_D.step()# 在虚假样本上进行训练z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake.detach())loss_D_fake = loss_fn(D_fake, fake_labels)opt_D.zero_grad()loss_D_fake.backward()opt_D.step()d_loss = (loss_D_real + loss_D_fake) / 2# 训练生成器z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake)loss_G = loss_fn(D_fake, real_labels) # 生成器希望D认为它生成的是真的opt_G.zero_grad()loss_G.backward()opt_G.step()print(f"Epoch [{epoch+1}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {loss_G.item():.4f}")
使用训练后的模型生成伪造数据:
# 采样生成图片并显示
import matplotlib.pyplot as pltG.eval()
with torch.no_grad():z = torch.randn(16, z_dim, device=device)fake_images = G(z).cpu()fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flatten()):ax.imshow(fake_images[i].squeeze().reshape(28, 28), cmap='gray')ax.axis('off')
plt.tight_layout()
plt.show()