生成式人工智能实战 | 生成对抗网络(Generative Adversarial Network, GAN)

生成式人工智能实战 | 生成对抗网络

    • 0. 前言
    • 1. 生成对抗网络
    • 2. 模型构建
      • 2.1 生成器
      • 2.2 判别器
    • 3. 模型训练
      • 3.1 数据加载
      • 3.2 训练流程

0. 前言

生成对抗网络 (Generative Adversarial Networks, GAN) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用广泛,且 GAN 的相关研究正在迅速发展,以生成与真实图像难以区分的逼真图像。在本节中,我们将学习 GAN 网络的原理并使用 PyTorch 实现 GAN

1. 生成对抗网络

生成对抗网络 (Generative Adversarial Networks, GAN) 包含两个网络:生成网络( Generator,也称生成器)和判别网络( discriminator,也称判别器)。在 GAN 网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像:

  • 生成器 G ( z ; θ G ) G(z;θ_G) G(z;θG) 接受噪声 z ∼ p z z∼p_z zpz​,学习映射到数据空间,以“欺骗”判别器
  • 判别器 D ( x ; θ D ) D(x;θ_D) D(x;θD) 输出样本 x x x 属于真实数据的概率,旨在区分真实与生成数据
    两者通过以下最小–最大化 (minimax) 目标函数进行博弈:
    m i n G m a x D ⁡ V ( D , G ) = E x ∼ p d a t a [ l o g ⁡ D ( x ) ] + E z ∼ p z [ l o g ⁡ ( 1 − D ( G ( z ) ) ) ] \underset G {min} \underset D {max} ⁡V(D,G)=\mathbb E_{x∼p_{data}}[log⁡D(x)]+\mathbb E_{z∼p_z}[log⁡(1−D(G(z)))] GminDmaxV(D,G)=Expdata[logD(x)]+Ezpz[log(1D(G(z)))]
    生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN 中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:

GAN原理

在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN 的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:

  • 生成网络的训练过程:冻结判别网络权重,生成网络以噪声 z 作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像
  • 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像

重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。

2. 模型构建

2.1 生成器

生成器由若干全连接层与 LeakyReLU 激活构成,最后用 Tanh 将输出映射至 [−1,1] 范围内:

# 定义生成器 G
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=10):super().__init__()self.net = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 28*28),nn.Tanh()  # 输出像素映射到 [-1,1])def forward(self, z):return self.net(z).view(-1,1,28,28)

2.2 判别器

判别器使用全连接层与 LeakyReLU 激活,末端使用 Sigmoid 激活函数,输出一个标量真值估计:

# 定义判别器 D,对输入图片输出真伪概率
class Discriminator(nn.Module):def __init__(self, img_dim=28*28):super().__init__()self.model = nn.Sequential(nn.Flatten(),nn.Linear(img_dim, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):return self.model(x)

3. 模型训练

接下来,使用 MNIST 数据集训练 GAN 模型。

3.1 数据加载

MNIST 像素值归一化到生成器 Tanh 输出所需的 [−1,1] 区间:

# 加载并归一化 MNIST 数据集
from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 映射到 [-1,1]
])
train_ds = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)

3.2 训练流程

首先,初始化模型、优化器与损失函数:

import torch
import torch.optim as optimdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
z_dim = 20
G = Generator(z_dim).to(device)
D = Discriminator().to(device)opt_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
loss_fn = nn.BCELoss()

训练模型 50epoch

epochs = 50for epoch in range(epochs):for real, _ in train_loader:real = real.to(device)batch_size = real.size(0)real_labels = torch.ones(batch_size, 1, device=device)fake_labels = torch.zeros(batch_size, 1, device=device)# 训练判别器# 在真实样本上进行训练D_real = D(real)loss_D_real = loss_fn(D_real, real_labels)opt_D.zero_grad()loss_D_real.backward()opt_D.step()# 在虚假样本上进行训练z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake.detach())loss_D_fake = loss_fn(D_fake, fake_labels)opt_D.zero_grad()loss_D_fake.backward()opt_D.step()d_loss = (loss_D_real + loss_D_fake) / 2# 训练生成器z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake)loss_G = loss_fn(D_fake, real_labels)  # 生成器希望D认为它生成的是真的opt_G.zero_grad()loss_G.backward()opt_G.step()print(f"Epoch [{epoch+1}/{epochs}]  Loss_D: {d_loss.item():.4f}  Loss_G: {loss_G.item():.4f}")

使用训练后的模型生成伪造数据:

# 采样生成图片并显示
import matplotlib.pyplot as pltG.eval()
with torch.no_grad():z = torch.randn(16, z_dim, device=device)fake_images = G(z).cpu()fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flatten()):ax.imshow(fake_images[i].squeeze().reshape(28, 28), cmap='gray')ax.axis('off')
plt.tight_layout()
plt.show()

生成结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/86496.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/86496.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

缓存与加速技术实践-MongoDB数据库应用

一.什么是MongoDB MongoDB 是一个文档型数据库,数据以类似 JSON 的文档形式存储。 MongoDB 的设计理念是为了应对大数据量、高性能和灵活性需求。 MongoDB 使用集合(Collections)来组织文档(Documents)&#xff0…

声网对话式AI把“答疑机器人”变成“有思维的助教”

作为一家专注初高中学生的线上教育平台,我们精心打磨的系统化课程收获了不少认可,但课后无人答疑的难题却始终横亘在前。学生课后遇到疑惑,要么只能默默憋在心里,要么就得苦苦等待下一节课,家长们也频繁抱怨 “花了钱&…

常见的排序方法

目录 1. 插入排序 2. 希尔排序 3. 选择排序 4. 堆排序 5. 冒泡排序 6. 快速排序 1. 快速排序的实现 1. 思路(以从小到大排序为例) 2. 选取基准元素的方法(Hoare) 3. 选取基准元素的方法(挖坑法) …

【matlab定位例程】基于AOA和TDOA混合的定位方法,背景为三维空间,自适应锚点数量,附下载链接

文章目录 代码概述代码功能概述核心算法原理AOA定位模型TDOA定位迭代算法混合定位策略关键技术创新 运行结果4个锚点的情况40个锚点的情况 MATLAB源代码 代码概述 代码功能概述 本代码实现了一种三维空间中的混合定位算法,结合到达角( A O A AOA AOA&a…

专题:2025医疗AI应用研究报告|附200+份报告PDF汇总下载

原文链接:https://tecdat.cn/?p42748 本报告汇总解读聚焦医疗行业人工智能应用的前沿动态与市场机遇,以数据驱动视角剖析技术演进与商业落地的关键路径。从GenAI在医疗领域的爆发式增长,到细分场景的成熟度矩阵,再到运营成本压力…

推荐一个前端基于vue3.x,vite7.x,后端基于springboot3.4.x的完全开源的前后端分离的中后台管理系统基础项目(纯净版)

XHan Admin 简介 🎉🎉 XHan Admin 是一个开箱即用的开源中后台管理系统基础解决方案, 项目为前后端分离架构。采用最新的技术栈全新构建,纯净的项目代码,没有历史包袱。 前端使用最新发布的 vite7.0 版本构建&#xf…

MySQL误删数据急救指南:基于Binlog日志的实战恢复详解

背景 数据误删是一个比较严重的场景 1.典型误操作场景 场景1:DELETE FROM orders WHERE status0 → 漏写AND create_time>‘2025-06-20’ 场景2:DROP TABLE customer → 误执行于生产环境 认识 binlog 1.binlog 的核心作用 记录所有 DDL/DML 操…

高效数据采集方案:快速部署与应用 AnyCrawl 网页爬虫工具实操指南

以下是对 AnyCrawl 的简单介绍: AnyCrawl 提供高性能网页数据爬取,其功能专为 LLM 集成和数据处理而设计支持利用搜索引擎直接查询获取结果内容,类似 searxng提供开发者友好的API,支持动态内容抓取,并输出结构化数据&…

vue3可以分页、搜索的select

下载 npm i v-selectpage基本使用 import { SelectPageList } from v-selectpage;<SelectPageListlanguage"zh-chs"key-prop"id"label-prop"name"fetch-data"fetchData" />const fetchData (data,callback) > {const { sea…

C# 入门学习教程 (一)

文章目录 一、解决方案与项目1. Solution 与 project 二、类与名称空间1.类与名称空间2.类库的引用1. DLL引用&#xff08;黑盒引用&#xff0c;无源代码&#xff09;2. Nuget 引用3. 项目引用&#xff08;白盒引用&#xff0c;有源代码&#xff09; 3.依赖关系 三、类&#xf…

76、单元测试-参数化测试

76、单元测试-参数化测试 参数化测试是一种单元测试技术&#xff0c;通过将测试数据与测试逻辑分离&#xff0c;使用不同的输入参数多次运行相同的测试用例&#xff0c;从而提高测试效率和代码复用性。 #### 基本原理 - **数据驱动测试**&#xff1a;将测试数据参数化&#xf…

SQL学习笔记3

SQL常用函数 1、字符串函数 函数调用的语法&#xff1a;select 函数&#xff08;参数); 常用的字符串函数有&#xff1a; 拼接字符串&#xff0c;将几个字符串拼到一起&#xff1a;concat (s1,s2,……); select concat(你好,hello); update mytable set wherefo concat(中…

Golang 面向对象编程,如何实现 封装、继承、多态

Go语言虽然不是纯粹的面向对象语言&#xff0c;但它通过结构体(struct)、接口(interface)和方法(method)提供了面向对象编程的能力。下面我将通过具体示例展示Go中如何实现类、封装、继承、多态以及构造函数等概念。 1. 类与封装 在Go中&#xff0c;使用结构体(struct)来定义…

为什么android要使用Binder机制

1.linux中大多数标准 IPC 场景&#xff08;如管道、消息队列、ioctl 等&#xff09;的进程间通信机制 ------------------ ------------------ ------------------ | 用户进程 A | | 内核空间 | | 用户进程 B | | (User Spa…

OpenCV CUDA模块设备层-----双曲余弦函数cosh()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 该函数用于计算四维浮点向量&#xff08;float4类型&#xff09;的双曲余弦值&#xff0c;作用于CUDA设备端。双曲余弦函数定义为cosh(x) (eˣ …

48页PPT | 企业数字化转型关键方法论:实践路径、案例和落地评估框架

目录 一、什么是企业数据化转型&#xff1f; 二、为什么要进行数据化转型&#xff1f; 1. 市场复杂性与不确定性上升 2. 内部流程效率与协同难题突出 3. 数字资产沉淀不足&#xff0c;智能化基础薄弱 三、数据化流程管理&#xff1a;从“业务流程”到“数据流程”的对齐 …

VTK中的形态学处理

VTK图像处理代码解析:阈值化与形态学开闭运算 这段代码展示了使用VTK进行医学图像处理的两个关键步骤:阈值分割和形态学开闭运算。下面我将详细解析每个部分的功能和实现原理。 处理前 处理后 1. 阈值分割部分 (vtkImageThreshold) vtkSmartPointer<vtkImageThresho…

xlsx.utils.sheet_to_json() 方法详解

sheet_to_json() 是 SheetJS/xlsx 库中最常用的方法之一&#xff0c;用于将 Excel 工作表&#xff08;Worksheet&#xff09;转换为 JSON 格式数据。下面我将全面讲解它的用法、参数配置和实际应用场景。 基本语法 javascript 复制 下载 const jsonData XLSX.utils.sheet…

〔从零搭建〕BI可视化平台部署指南

&#x1f525;&#x1f525; AllData大数据产品是可定义数据中台&#xff0c;以数据平台为底座&#xff0c;以数据中台为桥梁&#xff0c;以机器学习平台为中层框架&#xff0c;以大模型应用为上游产品&#xff0c;提供全链路数字化解决方案。 ✨杭州奥零数据科技官网&#xf…

合规型区块链RWA系统解决方案报告——机构资产数字化的终极武器

&#xff08;跨境金融科技解决方案白皮书&#xff09; 一、直击机构客户四大痛点 痛点传统方案缺陷我们的破局点✖️ 跨境资产流动性差结算周期30天&#xff0c;摩擦成本超8%▶️ 724h全球实时交易&#xff08;速度提升90%&#xff09;✖️ 合规成本飙升KYC/AML人工审核占成本…