生成对抗网络详解与实现

生成对抗网络详解与实现

    • 0. 前言
    • 1. GAN 原理
    • 2. GAN 架构
    • 3. 损失函数
      • 3.1 判别器损失
      • 3.2 生成器损失
      • 3.4 VANILLA GAN
    • 4. GAN 训练步骤

0. 前言

生成对抗网络 (Generative Adversarial Network, GAN) 是图像和视频生成中的主要方法之一。在本节中,我们将了解 GAN 的架构、训练步骤等,并实现原始 GAN

1. GAN 原理

生成模型的目的是学习数据分布并从中进行采样以生成新数据。PixelCNN 和变分自编码器 (Variational Autoencoder, VAE),它们的生成部分将着眼于训练过程中的图像分布。因此,称为显式密度模型 (explicit density models)。相比之下,GAN 中的生成部分不会直接查看图像。因此,GAN 被归类为隐式密度模型 (implicit density models)。
我们可以使用一个类比来比较显式模型和隐式模型。假设一位艺术系学生 G 获得了毕加索的画作收藏,并被要求学习绘制假毕加索画作。学生可以在学习绘画时查看收藏,因此这是一个显式模型。在另一种情况下,我们要求学生 G 伪造毕加索的画,但我们没有给他们看任何画,他们也不知道毕加索的画是什么样。他们学习的唯一方法是学生 D 的反馈,后者正在学习判别假毕加索的画作。反馈很简单——这幅画是假的还是真实的。这就是我们的隐式密度 GAN 模型。
也许有一天,G 偶然地画了一张扭曲的脸,并从反馈中得知它看起来像一幅真正的毕加索画,然后他们开始以这种方式来欺骗学生 D。学生 GDGAN 中的两个网络,称为生成器和判别器。与其他生成模型相比,这是网络体系结构的最大区别。我们将从了解 GAN 构建块开始,然后介绍损失函数。然后,我们将为 GAN 创建自定义的训练步骤。

2. GAN 架构

生成对抗网络中的对抗一词是指包含对立或异议。有两个相互竞争的网络,称为生成器和判别器。顾名思义,生成器生成伪造的图像。而辨别器将查看生成的图像,以确定它们是真实的还是伪造的。每个网络都试图赢得这场比赛,判别器要正确识别每个真实和伪造的图像,而生成器则要愚弄判别器以使其所产生的虚假图像被判别器判定是真实的。下图显示了 GAN 的体系结构:

GAN

GAN 架构与 VAE 有一些相似之处。如果 VAE 由两个独立的网络组成,我们可以想到:

  • GAN 的生成器作为 VAE 的解码器
  • GAN 的判别器作为 VAE 的编码器

生成器将低维和简单分布转换为具有复杂分布的高维图像,就像解码器一样。生成器的输入通常是来自正态分布的样本,也有些样本使用均匀分布。
我们将不同批次的图像发送给判别器。真实图像是来自数据集的图像,而伪造图像则是由生成器生成的。判别器输出输入图片是真还是假的单值概率。它是一个二进制分类器,可以使用 CNN 来实现它。从技术上讲,判别器的作用与编码器不同,但它们都减小了输入的维数。
实际上,原始的 GAN 仅使用了多层感知器,该感知器由一些基本的全连接层组成。

3. 损失函数

损失函数体现了 GAN 的工作原理。公式如下:
minGmaxDV(D,G)=EX∼Pdata(x)[logD(x)]+EZ∼Pz(z)[log(1−D(G(z)))]min_Gmax_DV(D,G)=E_{X\sim P_data(x)}[logD(x)]+E_{Z\sim P_z(z)}[log(1-D(G(z)))] minGmaxDV(D,G)=EXPdata(x)[logD(x)]+EZPz(z)[log(1D(G(z)))]

其中:DDD 表示判别器,GGG 表示生成器,xxx 表示输入数据,zzz 表示潜变量。
了解 GAN 的损失函数之后,代码实现将变得更加容易。此外,有关 GAN 改进的许多讨论都围绕损失函数进行。GAN 损失函数也称为对抗损失。接下来,我们将对其进行分解,并逐步向展示如何将其转换为我们可以实现的简单损失函数。

3.1 判别器损失

GAN 损失函数的等式右侧的第一项是用于正确分类真实图像的值。从等式左边的项来看,我们知道判别器想要将其最大化。期望是一个数学术语,是随机变量每个样本的加权平均值之和。在此等式中,权重是数据的概率,而变量是判别器输出的对数,如下所示:
EX[logD(x)]=∑i=1Np(x)logD(x)=1N∑i=1NlogD(x)E_X[logD(x)]=\sum_{i=1}^Np(x)logD(x)=\frac 1N\sum_{i=1}^NlogD(x) EX[logD(x)]=i=1Np(x)logD(x)=N1i=1NlogD(x)
在大小为 NNN 的小批次中,p(x)p(x)p(x)1N\frac 1 NN1。这是因为 xxx 是单个图像。不必尝试使它最大化,我们可以将符号更改为减号并尝试使其最小化。这可以借助以下方程来完成,该方程称为对数损失:
minDV(D)=−1N∑i=1NlogD(x)=−1N∑i=1Nyilogp(yi)min_DV(D)=-\frac 1N\sum_{i=1}^NlogD(x)=-\frac 1N\sum_{i=1}^Ny_ilogp(y_i) minDV(D)=N1i=1NlogD(x)=N1i=1Nyilogp(yi)
其中:yiy_iyi 是标签,对于真实图像为 1p(yi)p(y_i)p(yi) 是样本为真的概率。
GAN 损失函数的等式右侧的第二项是关于伪造图像的。zzz 是随机噪声,并且 G(z)G(z)G(z) 是生成图像。D(G(z))D(G(z))D(G(z)) 是判别器对图像真实可能性的置信度得分。如果我们将标签 0 用于伪造图像,则可以使用相同的方法将其转换为以下等式:
−EZ∼Pz(z)[log(1−D(G(z))]=−1N∑i=1N(1−yi)log(1−p(yi))-E_{Z\sim P_z(z)}[log(1-D(G(z))]=-\frac 1N\sum_{i=1}^N(1-y_i)log(1-p(y_i)) EZPz(z)[log(1D(G(z))]=N1i=1N(1yi)log(1p(yi))
现在,将所有内容放在一起,我们有了判别器损失函数,即二进制交叉熵损失:
minDV(D)=−1N∑i=1Nyilogp(yi)+(1−yi)log(1−p(yi))min_DV(D)=-\frac 1N\sum_{i=1}^Ny_ilogp(y_i)+(1-y_i)log(1-p(y_i)) minDV(D)=N1i=1Nyilogp(yi)+(1yi)log(1p(yi))
使用以下代码实现判别器损失:

def discriminator_loss(pred_fake, pred_real):real_loss = bce(tf.ones_like(pred_real), pred_real)fake_loss = bce(tf.zeros_like(pred_fake), pred_fake)d_loss = 0.5 *(real_loss + fake_loss)return d_loss

在我们的训练中,我们使用相同的批大小分别对真实和伪造图像进行前向传递。因此,我们分别为它们计算二进制交叉熵损失,并取平均值作为损失。

3.2 生成器损失

仅当模型判别伪造图像时才涉及生成器,因此我们只需要查看 GAN 损失函数的等式右侧第二项并将其简化为:
minGV(G)=EZ∼Pz(z)[log(1−D(G(z))]min_GV(G)=E_{Z\sim P_z(z)}[log(1-D(G(z))] minGV(G)=EZPz(z)[log(1D(G(z))]
在训练开始时,生成器并不擅长生成图像,因此判别器始终有信心将其归类为 0,使 D(G(z))D(G(z))D(G(z)) 始终为 0log(1–0)log(1 – 0)log(1–0) 也是如此。当模型输出中的误差始终为 0 时,则没有反向传播的梯度。结果,生成器的权重未更新,并且生成器未学习。由于判别器的 sigmoid 输出几乎没有梯度,因此这种现象称为梯度饱和 (saturating gradient)。为避免此问题,将等式从最小化 1−D(G(z))1-D(G(z))1D(G(z)) 到最大化 D(G(z))D(G(z))D(G(z)) 进行如下转换:
maxGV(G)=EZ∼Pz(z)[logD(G(z))]max_GV(G)=E_{Z\sim P_z(z)}[logD(G(z))] maxGV(G)=EZPz(z)[logD(G(z))]
使用此函数的 GAN 也称为非饱和 GAN (Non-Saturating GANs, NS-GAN)。实际上,Vanilla GAN 的实现都使用此损失函数而不是原始的 GAN 损失函数。

3.4 VANILLA GAN

GAN 诞生后,研究人员对 GAN 的兴趣激增,提出了一系列改进模型。Vanilla GAN 是泛指基本 GANVanilla GAN 通常使用具有两个或三个隐藏的全连接层来实现。
我们可以对判别器使用相同的数学步骤来推导生成器损失,最终将得到相同的判别器损失函数,只是将标签 1 用于伪造图像。为什么要对伪造图片使用标签 1,我们也可以这样理解它——因为我们想欺骗判别器以假定那些生成的图像是真实的,因此我们使用标签 1

def generator_loss(pred_fake):g_loss = bce(tf.ones_like(pred_fake), pred_fake)return g_loss

4. GAN 训练步骤

为了在 TensorFlow 中训练神经网络,我们需要指定模型,损失函数,优化器,然后调用 model.fit()TensorFlow 将为我们完成所有工作,我们等待损失减少。
在研究 GAN 问题之前,我们首先回顾神经网络在进行单个训练步骤时代码执行的情况:

  • 执行前向传播以计算损失
  • 使用损失相对于权重的梯度向后传播
  • 然后,这是更新权重。优化器将缩放梯度并将其添加到权重中,从而完成一个训练步骤

这些是深度神经网络中的通用训练步骤。各种优化器的不同之处仅在于它们计算缩放因子的方式。
现在回到 GAN,查看梯度流。当我们训练真实图像时,只涉及判别器–网络输入是真实图像,输出是 1 的标签。当我们使用伪造图像并且梯度通过判别器反向传播到生成器时,就会出现问题。让我们将伪造图像的生成器损失和判别器损失并排放置:

g_loss = bce(tf.ones_like(pred_fake), pred_fake)
fake_loss = bce(tf.zeros_like(pred_fake), pred_fake)

可以发现它们之间的差异,它们的标签是相反!这意味着,使用生成器损失来训练整个模型将使判别器朝相反的方向移动,而不会学会进行判别。这适得其反,我们不想有一个未经训练的判别器,这会阻止生成器学习。因此,我们必须分别训练生成器和判别器。训练生成器时,我们将冻结判别器权重。
有多种方法可以设计 GAN 训练流程。一种是使用高级 Keras 模型,该模型需要较少的代码,因此看起来更优雅。我们只需要定义一次模型,然后调用 train_on_batch() 即可执行所有步骤,包括前向计算,反向传播和权重更新。但是,在实现更复杂的损失函数时,灵活性较差。
另一种方法是使用低级函数,以便控制每个步骤。在本节中,GAN 将使用自定义训练步骤:

def train_step(g_input, real_input):with tf.GradientTape() as g_tape,\tf.GradientTape() as d_tape:# Forward passfake_input = G(g_input)pred_fake = D(fake_input)pred_real = D(real_input)   # Calculate lossesd_loss = discriminator_loss(pred_fake, pred_real)g_loss = generator_loss(pred_fake)

tf.GradientTape() 用于记录单次通过的梯度。另一个具有类似功能的 APItf.Gradient(),但后者在 TensorFlow Eager 执行中不起作用。我们将看到如何在 train_step() 中实现前面提到的三个过程步骤。前面的代码段显示了执行前向传递以计算损失的第一步。
第二步是使用 tape 梯度从它们各自的损失计算生成器和判别器的梯度:

        gradient_g = g_tape.gradient(g_loss, G.trainable_variables)gradient_d = d_tape.gradient(d_loss, D.trainable_variables)

第三步也是最后一步是使用优化器将梯度应用于模型权重:

        G_optimizer.apply_gradients(zip(gradient_g, self.G.trainable_variables))D_optimizer.apply_gradients(zip(gradient_d, self.D.trainable_variables))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/96575.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/96575.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FPGA硬件开发-XPE工具的使用

目录 XPE 工具概述​ XPE 使用步骤详解​ 1. 工具获取与初始化​ 2. 器件选择与配置​ 3. 电源电压设置​ 4. 资源使用量配置​ 5. 时钟与开关活动配置​ 6. 功耗计算与报告生成​ 报告解读与电源设计优化​ 常见问题与最佳实践​ 与实际功耗的差异处理​ 工具版本…

CentOS 7.9 RAID 10 实验报告

文章目录CentOS 7.9 RAID 10 实验报告一、实验概述1.1 实验目的1.2 实验环境1.3 实验拓扑二、实验准备2.1 磁盘准备2.2 安装必要软件三、RAID 10阵列创建3.1 创建RAID 10阵列3.2 创建文件系统并挂载3.3 保存RAID配置四、性能基准测试4.1 初始性能测试4.2 创建测试数据集五、故障…

机器人逆运动学进阶:李代数、矩阵指数与旋转流形计算

做机器人逆运动学(IK)的时候,你迟早会遇到矩阵指数和对数这些东西。为什么呢?因为计算三维旋转的误差,不能简单地用欧氏距离那一套,那只对位置有效。旋转得用另一套方法——你需要算两个旋转矩阵之间的差异…

计算机视觉(opencv)实战十八——图像透视转换

图像透视变换详解与实战在图像处理中,透视变换(Perspective Transform) 是一种常见的几何变换,用来将图像中某个四边形区域拉伸或压缩,映射到一个矩形区域。常见应用场景包括:纠正拍照时的倾斜(…

【飞书多维表格插件】

coze中添加飞书多维表格记录插件 添加单条记录 [{"fields":{"任务详情":"选项1","是否完成":"未完成"}}]添加多条记录 [{"fields":{"任务详情":"选项1","是否完成":"已完…

Java基础 9.14

1.Collection接口遍历对象方式2-for循环增强增强for循环,可以代替iterator选代器,特点:增强for就是简化版的iterator本质一样 只能用于遍历集合或数组package com.logic.collection_;import java.util.ArrayList; import java.util.Collectio…

数据结构(C语言篇):(十三)堆的应用

目录 前言 一、堆排序 1.1 版本一:基于已有数组建堆、取栈顶元素完成排序 1.1.1 实现逻辑 1.1.2 底层原理 1.1.3 应用示例 1.1.4 执行流程 1.2 版本二:原地排序 —— 标准堆排序 1.2.1 实现逻辑 1.2.2 底层原理 1.2.3 时间复杂度计算…

4步OpenCV-----扫秒身份证号

这段代码用 OpenCV 做了一份“数字模板字典”,然后在银行卡/身份证照片里自动找到身份证号那一行,把每个数字切出来跟模板比对,最终输出并高亮显示出完整的身份证号码,下面是代码解释:模块 1 工具箱(通用函…

冯诺依曼体系:现代计算机的基石与未来展望

冯诺依曼体系:现代计算机的基石与未来展望 引人入胜的开篇 当你用手机刷视频、用电脑办公时,是否想过这些设备背后共享的底层逻辑?从指尖轻滑切换APP,到电脑秒开文档,这种「无缝衔接」的体验,其实藏着一个改…

前端基础 —— C / JavaScript基础语法

以下是对《3.JavaScript(基础语法).pdf》的内容大纲总结:---📘 一、JavaScript 简介 - 定义:脚本语言,最初用于表单验证,现为通用编程语言。 - 应用:网页开发、游戏、服务器(Node.js&#xff09…

springboot 二手物品交易系统设计与实现

springboot 二手物品交易系统设计与实现 目录 【SpringBoot二手交易系统全解析】从0到1搭建你的专属平台! 🔍 需求确认:沟通对接 🗣 📊 系统功能结构:附思维导图 ☆开发技术: &#x1f6e…

【Android】可折叠式标题栏

在 Android 应用开发中,精美的用户界面可以显著提升应用品质和用户体验。Material Design 组件中的 CollapsingToolbarLayout 能够为应用添加动态、流畅的折叠效果,让标题栏不再是静态的元素。本文将深入探讨如何使用 CollapsingToolbarLayout 创建令人惊…

Debian13下使用 Vim + Vimspector + ST-LINK v2.1 调试 STM32F103 指南

1. 硬件准备与连接 1.1 所需硬件 STM32F103C8T6 最小系统板ST-LINK v2.1 调试器连接线(杜邦线) 1.2 硬件连接 ST-LINK v2.1 ↔ STM32F103C8T6 连接方式:ST-LINK v2.1 引脚STM32F103C8T6 引脚功能说明SWDIOPA13数据线SWCLKPA14时钟线GNDGND共地…

第21课:成本优化与资源管理

第21课:成本优化与资源管理 课程目标 掌握计算资源优化 学习成本控制策略 了解资源调度算法 实践实现成本优化系统 课程内容 21.1 成本分析框架 成本分析系统 class CostAnalysisFramework {constructor(config) {this.config

SAP HANA Scale-out 04:CalculationView优化

CV执行过程计算视图激活时,生成Stored ModelSELECT查询时:首先将Stored Model实例化为runtime Model 计算引擎执行优化,将runtime Model转换为Optimized Runtime ModelOptimized Runtime Model通过SQL Optimizer进行优化计算引擎优化特性说明…

鸿蒙审核问题——Scroll中嵌套了List/Grid时滑动问题

文章目录背景原因解决办法1、借鉴Flutter中的解决方式,如下图2、鸿蒙Next中对应的解决方式,如下图3、官方文档回访背景 来源一次审核被拒的情况。也是出于粗心导致的。之前在flutter项目中也是遇到过这种问题的。其实就是滚动视图内嵌滚动视图造成的&am…

测试电商购物车功能,设计测试case

在电商场景中,购物车是连接商品浏览与下单支付的关键环节,需要从功能、性能、兼容性、安全性等多维度进行测试。以下是购物车功能的测试用例设计: 一、功能测试 1. 商品添加到购物车 - 未登录状态下,添加商品到购物车(…

Linux --- 常见的基本指令

一. 前言本篇博客使用的 Linux 操作系统是 centos ,用来学习Linux 的 Linux 系统的内核版本和系统架构信息版本如下所示:上图的主要结构为:主版本号-次版本号 修正次数,3.10.0 是操作系统的主版本号;当我们在维护一段L…

微信小程序 -开发邮箱注册验证功能

一、前端验证:正则表达式与插件结合正则表达式设计 使用通用邮箱格式校验正则,并允许中文域名(如.中国): const emailReg /^[a-zA-Z0-9._%-][a-zA-Z0-9-](?:\.[a-zA-Z0-9-])*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2})?$/i;…

docker 部署 code-server

docker 部署 code-servercode-serverError response from daemon: Get "https://registry-1.docker.io/v2/": net/http: request canceled while waiting for connection (Client.Timeout exceeded while awaiting headersdocker 配置正确步骤 阿里云源permission de…