PyTorch生成式人工智能——ACGAN详解与实现

PyTorch生成式人工智能——ACGAN详解与实现

    • 0. 前言
    • 1. ACGAN 简介
      • 1.1 ACGAN 技术原理
      • 1.2 ACGAN 核心思想
      • 1.3 损失函数
    • 2. 模型训练流程
    • 3. 使用 PyTorch 构建 ACGAN
      • 3.1 数据处理
      • 3.2 模型构建
      • 3.3 模型训练
      • 3.4 模型测试
    • 相关链接

0. 前言

在生成对抗网络 (Generative Adversarial Network, GAN) 的众多变体中,ACGAN (Auxiliary Classifier GAN) 是一个非常经典且实用的条件生成模型。它的核心思想是:在判别器中除了保留“真假判别”这一任务外,额外加入一个辅助分类器,让判别器同时预测输入样本的类别。这样,生成器在训练时不仅需要“欺骗判别器”,还必须生成能够被正确分类的样本,从而在图像语义和类别可控性上得到显著提升。
这一改进让 ACGAN 能够在条件图像生成中表现出色,在复杂数据集上实现按类别生成的能力。相比于传统条件生成对抗网络 (Conditional GAN, cGAN) 简单地把标签拼接到输入,ACGAN 通过 “辅助分类监督” 提供了更细粒度的学习信号,使得生成器得到的梯度更加稳定和有意义。在本节中,将详细介绍 ACGAN 原理,并使用 PyTorch 构建 ACGAN 模型。

1. ACGAN 简介

1.1 ACGAN 技术原理

生成对抗网络 (Generative Adversarial Network, GAN) 的众多变体中,ACGAN (Auxiliary Classifier GAN) 能够从随机噪声中生成逼真的图像、文本甚至音乐。然而,传统的 GAN 有一个显著的局限性:缺乏对生成过程的精确控制。我们无法指定要生成“数字7”的图片还是一只“戴墨镜的猫”。
为了解决这个问题,条件生成对抗网络 (Conditional GAN, cGAN) 应运而生。它通过将类别标签信息同时注入生成器 (Generator) 和判别器 (Discriminator),实现了条件生成。但这仍然不够完美,cGAN 的判别器最终只输出一个“真/假”的概率,它并没有显式地告诉生成器它生成的图片在类别上是否正确。
ACGAN (Auxiliary Classifier GAN) 正是在 CGAN 的基础上,对判别器的任务进行了至关重要的扩展。它不仅判断真伪,还同时担任一个“分类器”的角色。这个简单的改变,极大地提升了生成图像的质量和多样性,尤其是在生成特定类别的图像时。

1.2 ACGAN 核心思想

ACGAN 的核心思想非常直观:为判别器增加一个辅助任务——对输入图像进行分类。其中,生成器的输入包括随机噪声向量 zzz 和目标类别标签 ccc;判别器的输出包括:

  • 一个源 (Source) 输出:一个标量概率,表示图像是来自真实数据分布的概率
  • 一个辅助类别 (Class) 输出:类别概率分布

通过引入这个辅助的分类任务,ACGAN 迫使判别器不仅要学习“什么样的图像是真实的”,还要学习“真实图像属于什么类别”。反过来,生成器为了欺骗这个更强大的判别器,也必须生成既逼真又类别分明的图像。

1.3 损失函数

损失函数包含两部分:

  • 源判别损失 (source loss),用来训练真假判别,通常使用二元交叉熵
  • 类别判别损失 (auxiliary classification loss),使用多元交叉熵(真实图像的类别为真实标签,生成图像的类别为生成器的条件标签

训练目标:

  • 判别器 D,最小化源判别损失(正确区分真实/虚假图像)并最小化类别判别损失(正确预测类别)
  • 生成器 G:生成图像以最大化判别器认为是“真实”的概率,并最小化判别器给出的类别预测与条件类的一致性

2. 模型训练流程

模型训练流程如下:

  1. 从真实数据中取一批数据 x,c{x,c}x,c
  2. 判别器更新:
    • 计算真实样本的源判别损失与类别判别损失
    • 用噪声和随机标签生成虚假样本 x~=G(z,c)\tilde x=G(z,c)x~=G(z,c),计算虚假样本的源判别损失与类别判别损失(可选)
    • 把这些损失加权后更新 D
  3. 生成器更新:
    • 用一批噪声与条件标签生成样本 x~\tilde xx~
    • 通过 D 计算源输出与辅助类别输出
    • 生成器的损失是希望源输出为“真实”,并希望辅助类别输出为生成时的条件标签
    • 更新 G

3. 使用 PyTorch 构建 ACGAN

接下来,使用 PyTorch 实现 ACGAN,并在 MNIST 数据集上进行训练生成手写数字。

3.1 数据处理

(1) 首先,导入所需库并设置超参数:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import os# 设置超参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100
num_classes = 10
batch_size = 64
lr = 0.0002
num_epochs = 100
sample_interval = 400# 创建输出目录
os.makedirs("images", exist_ok=True)
os.makedirs("models", exist_ok=True)

(2) 加载 MNIST 数据集,将图像转换为张量,并将像素值从 [0,1] 归一化到 [-1,1] 范围:

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=transform
)

(3) 构建数据加载器:

dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True
)

3.2 模型构建

(1) 首先,定义权重初始化函数:

def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)

(2) 定义生成器。生成器接收随机噪声和类别标签作为输入,通过嵌入层将标签转换为与噪声相同维度的向量,然后将二者相乘融合,之后通过全连接层和转置卷积层逐步上采样,最终生成 28 x 28 的图像:

class Generator(nn.Module):def __init__(self, latent_dim, num_classes):super(Generator, self).__init__()# 将类别标签转换为嵌入向量self.label_emb = nn.Embedding(num_classes, latent_dim)self.init_size = 7  # 初始特征图大小self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 1, 3, stride=1, padding=1),nn.Tanh())def forward(self, noise, labels):# 将噪声和标签嵌入相乘gen_input = torch.mul(self.label_emb(labels), noise)out = self.l1(gen_input)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return img

除了将类别标签转换为嵌入向量进行融合外,也可以直接使用标签的独热编码与噪声向量进行拼接。

(3) 定义判别器。判别器使用卷积层逐步提取特征,最后通过两个全连接层分别输出样本真伪的概率(源判别输出)和类别概率(类别判别输出):

class Discriminator(nn.Module):def __init__(self, num_classes):super(Discriminator, self).__init__()# 卷积层提取特征self.features = nn.Sequential(# 输入: 1x28x28nn.Conv2d(1, 16, 3, stride=2, padding=1),  # 16x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x7x7nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(32, 0.8),nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 64x4x4nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(64, 0.8),nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 128x2x2nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25),nn.BatchNorm2d(128, 0.8),)# 计算特征图大小: 128 * 2 * 2 = 512self.feature_size = 128 * 2 * 2# 输出真实/虚假的概率self.adv_layer = nn.Sequential(nn.Linear(self.feature_size, 1), nn.Sigmoid())# 输出类别概率self.aux_layer = nn.Sequential(nn.Linear(self.feature_size, num_classes), nn.Softmax(dim=1))def forward(self, img):# 提取特征features = self.features(img)features = features.view(features.size(0), -1)  # 展平# 预测真伪和类别validity = self.adv_layer(features)label = self.aux_layer(features)return validity, label

(4) 初始化生成器和判别器,并打印模型结构:

generator = Generator(latent_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 打印模型结构
print("Generator structure:")
print(generator)
print("\nDiscriminator structure:")
print(discriminator)

输出模型结构如下所示:

模型结构

3.3 模型训练

(1) 初始化损失函数和优化器:

# 定义损失函数
adversarial_loss = nn.BCELoss()
auxiliary_loss = nn.CrossEntropyLoss()# 定义优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

(2) 定义变量记录训练过程的损失变化:

G_losses = []
D_losses = []

(3) 实现训练循环。训练过程分为两个部分,先训练判别器,使其能正确区分真实和生成样本,并正确分类;然后训练生成器,使其能生成被判别器判定为真实且分类正确的样本:

# 训练循环
for epoch in range(num_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 准备真实/虚假标签valid = torch.ones(batch_size, 1).to(device)fake = torch.zeros(batch_size, 1).to(device)# 真实图像和标签real_imgs = imgs.to(device)real_labels = labels.to(device)#  训练判别器optimizer_D.zero_grad()# 真实样本的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, real_labels)) / 2# 生成虚假样本z = torch.randn(batch_size, latent_dim).to(device)gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)gen_imgs = generator(z, gen_labels)# 虚假样本的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 总判别器损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算判别器准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([real_labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()#  训练生成器optimizer_G.zero_grad()# 生成器希望判别器将虚假样本判断为真实validity, pred_label = discriminator(gen_imgs)g_loss = (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)) / 2g_loss.backward()optimizer_G.step()# 记录损失G_losses.append(g_loss.item())D_losses.append(d_loss.item())# 打印训练状态if i % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}, acc: {100*d_acc:.2f}%] "f"[G loss: {g_loss.item():.4f}]")# 定期保存生成的图像样本batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:# 保存生成器生成的图像save_image(gen_imgs.data[:25], f"images/{batches_done}.png", nrow=5, normalize=True)

(4) 训练完成后,保存模型权重:

torch.save(generator.state_dict(), "models/generator_final.pth")
torch.save(discriminator.state_dict(), "models/discriminator_final.pth")

(5) 绘制训练过程中的损失曲线

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("loss_curve.png")
plt.show()

训练过程监测

3.4 模型测试

images 文件夹中可以看到训练过程中生成的样本,随着训练进行,生成的数字越来越清晰:

模型训练过程

使用训练完成的模型,生成制定类别的数字:

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, i)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: {i}")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()

生成结果

生成数字 1

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):img = generate_digit(generator, 1)ax = axes[i//5, i%5]ax.imshow(img.cpu().squeeze(), cmap='gray')ax.set_title(f"Digit: 1")ax.axis('off')
plt.tight_layout()
plt.savefig("generated_digits.png")
plt.show()

生成结果

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(21)——归一化流模型(Normalizing Flow Model)
PyTorch生成式人工智能(27)——从零开始训练GPT模型
PyTorch生成式人工智能(28)——MuseGAN详解与实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/96333.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/96333.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python + 淘宝 API 开发:自动化采集商品数据的完整流程​

在电商数据分析、竞品监控和市场调研等场景中,高效采集淘宝商品数据是关键环节。本文将详细介绍如何利用 Python 结合 API,构建一套自动化的商品数据采集系统,涵盖从 API 申请到数据存储的完整流程,并提供可直接运行的代码实现。​…

2025.8.21总结

工作一年多了,在这期间,确实也有不少压力,但每当工作有压力的时候,最后面都会解决。好像每次遇到解决不了的事情,都有同事给我兜底。这种压力,确实会加速一个人的成长。这种狼性文化,这种环境&a…

VS2022 - C#程序简单打包操作

文章目录VS2022 - C#程序简单打包操作概述笔记实验过程新建工程让依赖的运行时程序安装包在安装时运行(如果发现运行时不能每次都安装程序,就不要做这步)关于”运行时安装程序无法每次都安装成功“的应对知识点尝试打包旧工程bug修复从需求属性中,可以原…

在JAVA中如何给Main方法传参?

一、在IDEA中进行传参:先创建一个类:MainTestimport java.util.Arrays;public class MainTest {public static void main(String[] args) {System.out.println(args.length);System.out.println(Arrays.toString(args));} }1.IDEA ---> 在运行的按钮上…

ORACLE中如何批量重置序列

背景:数据库所有序列都重置为1了,所以要将所有的序列都更新为对应的表主键(这里是id)的最大值1。我这里序列的规则是SEQ_表名。BEGINENHANCED_SYNC_SEQUENCES(WJ_CPP); -- 替换为你的模式名 END; / CREATE OR REPLACE PROCEDURE E…

公号文章排版教程:图文双排、添加图片超链接、往期推荐、推文采集(2025-08-21)

文章目录 排版的基本原则 I 图片超链接 方式1: 利用公号原生编辑器 方式2:在CSDN平台使用markdown编辑器, 利用标签实现图片链接。 II 排版小技巧 自定义页面模版教程 使用壹伴进行文章素材的采集 美编助手的往期推荐还不错 利用365编辑器创建图文双排效果 排版的基本原则 亲…

计算两幅图像在特定交点位置的置信度评分。置信度评分反映了该位置特征匹配的可靠性,通常用于图像处理任务(如特征匹配、立体视觉等)

这段代码定义了一个名为compute_confidence的函数,用于计算两幅图像在特定交点位置的置信度评分。置信度评分反映了该位置特征匹配的可靠性,通常用于图像处理任务(如特征匹配、立体视觉等)。以下是逐部分解析: 3. 结果…

计算机视觉第一课opencv(三)保姆级教学

简介 计算机视觉第一课opencv(一)保姆级教学 计算机视觉第一课opencv(二)保姆级教学 今天继续学习opencv。 一、 图像形态学 什么是形态学:图像形态学是一种处理图像形状特征的图像处理技术,主要用于描…

24.早期目标检测

早期目标检测 第一步,计算机图形学做初步大量候选框,把物体圈出来 第二步,依次将所有的候选框图片,输入到分类模型进行判断 选择性搜索 选择搜索算法(Selective Search),是一种熟知的计算机图像…

Java基础知识点汇总(三)

一、面向对象的特征有哪些方面 Java中面向对象的特征主要包括以下四个核心方面:封装(Encapsulation) 封装是指将对象的属性(数据)和方法(操作)捆绑在一起,隐藏对象的内部实现细节&am…

GEO优化专家孟庆涛:让AI“聪明”选择,为企业“精准”生长

在生成式AI席卷全球的今天,企业最常遇到的困惑或许是:“为什么我的AI生成内容总像‘模板套娃’?”“用户明明想要A,AI却拼命输出B?”当生成式AI从“能用”迈向“好用”的关键阶段,如何让AI真正理解用户需求…

【交易系统系列04】交易所版《速度与激情》:如何为狂飙的BTC交易引擎上演“空中加油”?

交易所版《速度与激情》:如何为狂飙的BTC交易引擎上演“空中加油”? 想象一下这个场景:你正端着一杯热气腾腾的咖啡,看着窗外我家那只贪睡的橘猫趴在阳光下打着呼噜。突然,手机上的警报开始尖叫,交易系统监…

windows下jdk环境切换为jdk17后,临时需要jdk1.8的处理

近段时间,终于决定把开发环境全面转向jdk17,这不就遇到了问题。 windows主环境已经设置为jdk17了。 修改的JAVA_HOME D:\java\jdk-17CLASSPATH设置 .;D:\java\jdk-17\lib\dt.jar;D:\java\jdk-17\lib\tools.jar;PATH中增加 D:\java\jdk-17\bin但是有些程序…

Android URC 介绍及源码案例参考

1. URC 含义 URC 是 Unsolicited Result Code(非请求结果码)的缩写。 它是 modem(基带)在不需要 AP 主动请求的情况下向上层主动上报的消息。 典型例子:短信到达提示、网络状态变更、来电通知、信号质量变化等。 URC 一般以 AT 命令扩展的形式从 modem 发到 AP,例如串口…

VB.NET发送邮件给OUTLOOK.COM的用户,用OUTLOOK.COM邮箱账号登录给别人发邮件

在VB.NET中通过代码发送邮件时,确实会遇到邮箱服务的身份认证(Authentication)要求。特别是微软Outlook/Hotmail等服务,已经逐步禁用传统的“基本身份验证”(Basic Authentication),转而强制要求…

【网络运维】Shell:变量进阶知识

Shell 变量进阶知识 Shell 中的特殊变量 位置参数变量 Shell 脚本中常用的位置参数变量如下: $0:获取当前执行的 Shell 脚本文件名(包含路径时包括路径)$n:获取第 n 个参数值(n>9 时需使用 ${n}&#xf…

部署Qwen2.5-VL-7B-Instruct-GPTQ-Int3

模型下载 from modelscope import snapshot_download model_dir snapshot_download(ChineseAlpacaGroup/Qwen2.5-VL-7B-Instruct-GPTQ-Int3)相关包导入 import os import numpy as np import pandas as pd from tqdm import tqdm from datetime import datetime,timedelta fro…

sourcetree 拉取代码

提示:文章旨在于教授大家 sourcetree 拉取代码的方式,关于代码的提交合并等操作后续会补充。 文章目录前言一、sourcetree 安装二、http 与 ssh 拉取代码1.http 方式(1)生成 token(2)拼接项目的 url&#x…

epoll模型网络编程知识要领

1、程序初始化创建监听socket调用bind函数绑定ip地址、port端口号调用listen函数监听调用epoll_create函数创建epollfd调用epoll_ctrl函数将listenfd绑定到epollfd上,监测listenfd的读事件在一个无限循环中,调用epoll_wait函数等待事件发生2、处理客户端…

15-day12LLM结构变化、位置编码和投机采样

多头机制transformer结构归一化层选择 归一化层位置归一化层类型激活函数Llama2结构MoE架构 混合专家模型DeepSeek MLA为何需要位置编码目前的主流位置编码正余弦位置编码可学习位置编码ROPE旋转位置编码推导参考: https://spaces.ac.cn/archives/8265 https://zhua…