生成对抗网络(GANs)中的损失函数公式 判别器最优解D^*(x)的推导

https://www.bilibili.com/video/BV1YyHSekEE2在这里插入图片描述

这张图片展示的是生成对抗网络(GANs)中的损失函数公式,特别是针对判别器(Discriminator)和生成器(Generator)的优化目标。让我们用Markdown格式逐步解析这些公式:

GAN的基本优化目标
markdown
深色版本
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{\boldsymbol{x} \sim p_{data}(\boldsymbol{x})}[\log D(\boldsymbol{x})] + \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log(1 - D(G(\boldsymbol{z})))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
min ⁡ G max ⁡ D \min_G \max_D minGmaxD 表示:这是一个最小最大博弈问题,其中生成器 G 和判别器 D 在相互竞争中进行优化。
E x ∼ p d a t a ( x ) \mathbb{E}_{\boldsymbol{x} \sim p_{data}(\boldsymbol{x})} Expdata(x):表示对真实数据分布 p d a t a ( x ) p_{data}(\boldsymbol{x}) pdata(x) 的期望值计算。
E z ∼ p z ( z ) \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})} Ezpz(z):表示对噪声分布 p z ( z ) p_{\boldsymbol{z}}(\boldsymbol{z}) pz(z) 的期望值计算。
D ( x ) D(\boldsymbol{x}) D(x):判别器输出的真实样本的概率。
G ( z ) G(\boldsymbol{z}) G(z):生成器根据噪声 z \boldsymbol{z} z 生成的样本。
log ⁡ D ( x ) \log D(\boldsymbol{x}) logD(x) log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(\boldsymbol{z}))) log(1D(G(z))):分别代表判别器正确识别真实样本和错误识别生成样本的对数概率。
判别器的损失函数
markdown
深色版本
l o s s d = − ( 1 N ∑ i = 1 N y i log ⁡ ( p i ) + ( 1 − y i ) log ⁡ ( 1 − p i ) ) loss_d = -\left(\frac{1}{N}\sum_{i=1}^{N}y_i\log(p_i) + (1-y_i)\log(1-p_i)\right) lossd=(N1i=1Nyilog(pi)+(1yi)log(1pi))
这是一个二分类交叉熵损失函数,用于衡量判别器在区分真实和生成样本时的性能。
y i y_i yi 是标签(1表示真实样本,0表示生成样本), p i p_i pi 是判别器预测的概率。
判别器损失函数的具体形式
markdown
深色版本
l o s s d = − ( 1 N ∑ i = 1 N y i log ⁡ ( D ( i m a g e i ) ) + ( 1 − y i ) log ⁡ ( 1 − D ( i m a g e i ) ) ) loss_d = -\left(\frac{1}{N}\sum_{i=1}^{N}y_i\log(D(image_i)) + (1-y_i)\log(1-D(image_i))\right) lossd=(N1i=1Nyilog(D(imagei))+(1yi)log(1D(imagei)))
这里将 p i p_i pi 替换为 D ( i m a g e i ) D(image_i) D(imagei),即判别器对图像 i m a g e i image_i imagei 的输出概率。
判别器损失函数的进一步分解
markdown
深色版本
min ⁡ l o s s d = − ( 1 N r e a l ∑ i log ⁡ ( D ( x i ) ) + 1 N f a k e ∑ i log ⁡ ( 1 − D ( G ( z i ) ) ) ) \min loss_d = -\left(\frac{1}{N_{real}}\sum_i\log(D(x_i)) + \frac{1}{N_{fake}}\sum_i\log(1-D(G(z_i)))\right) minlossd=(Nreal1ilog(D(xi))+Nfake1ilog(1D(G(zi))))
这个公式明确地将损失分为两部分:一部分是对于真实样本 x i x_i xi 的损失,另一部分是对于生成样本 G ( z i ) G(z_i) G(zi) 的损失。
最大化判别器的目标
markdown
深色版本
max ⁡ 1 N r e a l ∑ i log ⁡ ( D ( x i ) ) + 1 N f a k e ∑ i log ⁡ ( 1 − D ( G ( z i ) ) ) \max \frac{1}{N_{real}}\sum_i\log(D(x_i)) + \frac{1}{N_{fake}}\sum_i\log(1-D(G(z_i))) maxNreal1ilog(D(xi))+Nfake1ilog(1D(G(zi)))
这个公式展示了判别器的目标是最大化其对真实样本的识别能力和对生成样本的拒绝能力。
通过上述公式,我们了解了GAN中判别器和生成器之间的博弈过程,以及如何通过优化损失函数来训练这两个模型,以达到生成高质量样本的目的。


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 784),nn.Tanh())def forward(self, input):return self.main(input)# 定义判别器
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(784, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, input):return self.main(input)# 初始化模型、损失函数和优化器
generator = Generator()
discriminator = Discriminator()criterion = nn.BCELoss()  # Binary Cross Entropy Loss
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform), batch_size=64, shuffle=True)# 训练循环
num_epochs = 10
for epoch in range(num_epochs):for i, (imgs, _) in enumerate(train_loader):# 准备数据valid = torch.ones(imgs.size(0), 1)fake = torch.zeros(imgs.size(0), 1)real_imgs = imgs.view(imgs.size(0), -1)# 训练判别器optimizer_d.zero_grad()z = torch.randn(imgs.size(0), 100)gen_imgs = generator(z)loss_real = criterion(discriminator(real_imgs), valid)loss_fake = criterion(discriminator(gen_imgs.detach()), fake)loss_d = (loss_real + loss_fake) / 2loss_d.backward()optimizer_d.step()# 训练生成器optimizer_g.zero_grad()loss_g = criterion(discriminator(gen_imgs), valid)loss_g.backward()optimizer_g.step()print(f"Epoch [{epoch}/{num_epochs}] Loss D: {loss_d.item()}, loss G: {loss_g.item()}")print('训练完成')

在这里插入图片描述
判别器损失(Loss D)
趋势:判别器损失在训练过程中表现出较大的波动性。初期较低(<0.2),中期升高(最高达到0.5776),然后又回落。
解释:这表明模型在学习过程中经历了不同的阶段,其中生成器在某些时期能够较好地欺骗判别器,导致判别器的损失增加。
生成器损失(Loss G)
趋势:生成器损失从最初的较高水平逐渐下降到最低点(1.5655),然后有所回升并在后续的epoch中保持在一个相对较高的水平(约2.5至4之间)。
解释:这种模式可能意味着生成器的学习速率与判别器相比不够平衡,或者存在过拟合现象。特别是在第4个epoch之后,生成器损失的上升可能表示生成器遇到了瓶颈或难以进一步优化。
🔍 结论与建议
稳定性问题:由于损失值的大幅波动,可能需要调整超参数来稳定训练过程。比如:
调整学习率。
应用梯度惩罚或其他正则化技术以增强训练稳定性。
网络架构或数据集问题:如果损失值持续不稳定,考虑检查数据集是否足够多样化,以及网络架构是否有改进空间。
早期停止策略:可以实施早期停止策略来防止过拟合,并确保模型在验证集上的性能不会恶化。
可视化:定期保存并查看生成样本,可以帮助理解模型的实际表现和进步情况。


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt# 检查是否可用 CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# -----------------------------
# 模型定义
# -----------------------------
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 784),nn.Tanh())def forward(self, input):return self.main(input)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(784, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, input):return self.main(input)# -----------------------------
# 初始化模型和优化器
# -----------------------------
generator = Generator().to(device)
discriminator = Discriminator().to(device)criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)# -----------------------------
# 数据加载
# -----------------------------
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transform),batch_size=64,shuffle=True
)# -----------------------------
# 可视化函数
# -----------------------------
def show_images(images, epoch):images = images.view(-1, 1, 28, 28)grid = utils.make_grid(images, nrow=4, normalize=True)plt.figure(figsize=(5, 5))plt.title(f"Epoch {epoch}")plt.imshow(grid.permute(1, 2, 0).cpu())plt.axis("off")plt.show()# 固定噪声,用于每轮观察生成效果变化
fixed_noise = torch.randn(16, 100, device=device)# -----------------------------
# 训练循环
# -----------------------------
num_epochs = 20for epoch in range(num_epochs):for i, (imgs, _) in enumerate(train_loader):imgs = imgs.to(device)real_labels = torch.ones(imgs.size(0), 1).to(device)fake_labels = torch.zeros(imgs.size(0), 1).to(device)# ---------------------#  训练判别器# ---------------------optimizer_d.zero_grad()real_imgs = imgs.view(imgs.size(0), -1)d_real_loss = criterion(discriminator(real_imgs), real_labels)z = torch.randn(imgs.size(0), 100).to(device)gen_imgs = generator(z).detach()d_fake_loss = criterion(discriminator(gen_imgs), fake_labels)loss_d = (d_real_loss + d_fake_loss) / 2loss_d.backward()optimizer_d.step()# ---------------------#  训练生成器# ---------------------optimizer_g.zero_grad()gen_imgs = generator(z)loss_g = criterion(discriminator(gen_imgs), real_labels)loss_g.backward()optimizer_g.step()print(f"Epoch [{epoch}/{num_epochs}] Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")# 每个epoch结束后可视化生成结果with torch.no_grad():generated = generator(fixed_noise).cpu()show_images(generated, epoch)print("训练完成")

加入可视化
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
https://www.kaggle.com/code/alihhhjj/notebook8be232dcc8
在这里插入图片描述

判别器最优解 $ D^*(x) $ 的推导

在生成对抗网络(GAN)中,判别器的目标是最大化以下目标函数:

V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] V(D, G) = \mathbb{E}_{\boldsymbol{x} \sim p_{data}(\boldsymbol{x})}[\log D(\boldsymbol{x})] + \mathbb{E}_{\boldsymbol{z} \sim p_z(\boldsymbol{z})}[\log(1 - D(G(\boldsymbol{z})))] V(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

固定生成器 $ G $,我们希望找到使 $ V(D, G) $ 最大化的最优判别器 $ D^*(x) $。

一、简化目标函数

考虑对某个固定的输入样本 $ x $,我们可以将目标函数简化为一个关于 $ D(x) $ 的函数:

L ( D ( x ) ) = a log ⁡ D ( x ) + b log ⁡ ( 1 − D ( x ) ) L(D(x)) = a \log D(x) + b \log (1 - D(x)) L(D(x))=alogD(x)+blog(1D(x))

其中:

  • $ a = p_{data}(x) $:真实数据分布中 $ x $ 的概率密度;
  • $ b = p_g(x) $:生成器生成的数据分布中 $ x $ 的概率密度。

二、求极值:令导数为零

对 $ D(x) $ 求导并令其等于 0:

d L d D ( x ) = a D ( x ) − b 1 − D ( x ) = 0 \frac{dL}{dD(x)} = \frac{a}{D(x)} - \frac{b}{1 - D(x)} = 0 dD(x)dL=D(x)a1D(x)b=0

整理得:

a D ( x ) = b 1 − D ( x ) \frac{a}{D(x)} = \frac{b}{1 - D(x)} D(x)a=1D(x)b

这就是你看到的等式形式:

a D ∗ ( x ) = b 1 − D ∗ ( x ) \frac{a}{D^*(x)} = \frac{b}{1 - D^*(x)} D(x)a=1D(x)b

三、解方程求出 $ D^*(x) $

交叉相乘:

a ( 1 − D ∗ ( x ) ) = b D ∗ ( x ) a(1 - D^*(x)) = b D^*(x) a(1D(x))=bD(x)

展开并整理:

a = a D ∗ ( x ) + b D ∗ ( x ) ⇒ a = D ∗ ( x ) ( a + b ) a = a D^*(x) + b D^*(x) \Rightarrow a = D^*(x)(a + b) a=aD(x)+bD(x)a=D(x)(a+b)

解得:

D ∗ ( x ) = a a + b D^*(x) = \frac{a}{a + b} D(x)=a+ba

代入 $ a = p_{data}(x), b = p_g(x) $,得到最终结果:

D ∗ ( x ) = p d a t a ( x ) p d a t a ( x ) + p g ( x ) D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} D(x)=pdata(x)+pg(x)pdata(x)

四、意义

这表示在给定输入样本 $ x $ 的情况下,最优判别器 $ D^*(x) $ 输出的是该样本来自真实数据分布而非生成分布的概率。
在这里插入图片描述


在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/82472.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式爬虫架构设计

随着互联网数据的爆炸式增长&#xff0c;单机爬虫已经难以满足大规模数据采集的需求。分布式爬虫应运而生&#xff0c;它通过多节点协作&#xff0c;实现了数据采集的高效性和容错性。本文将深入探讨分布式爬虫的架构设计&#xff0c;包括常见的架构模式、关键技术组件、完整项…

[java]eclipse中windowbuilder插件在线安装

目录 一、打开eclipse 二、打开插件市场 三、输入windowbuilder&#xff0c;点击install 四、进入安装界面 五、勾选我同意... 重启即可 一、打开eclipse 二、打开插件市场 三、输入windowbuilder&#xff0c;点击install 四、进入安装界面 五、勾选我同意... 重启即可

sass,less是什么?为什么要使用他们?

理解 他们都是css的预处理器,允许开发者通过更高级的语法编写css代码(支持变量,嵌套),然后通过编译成css文件 使用原因 结构清晰,便于扩展提高开发效率,便于后期开发维护

Java设计模式之模板方法模式:从基础到高级的全面解析(最详解)

文章目录 一、模板方法模式基础概念1.1 什么是模板方法模式1.2 模板方法模式的核心结构1.3 模板方法模式中的方法分类1.4 模板方法模式的简单示例二、模板方法模式的深入解析2.1 模板方法模式的核心原理2.2 模板方法模式的优势与适用场景优势分析适用场景2.3 模板方法模式与其他…

【C/C++】如何在一个事件驱动的生产者-消费者模型中使用观察者进行通知与解耦

文章目录 如何在一个事件驱动的生产者-消费者模型中使用观察者进行通知与解耦?1 假设场景设计2 Codes3 流程图4 优劣势5 风险可能 如何在一个事件驱动的生产者-消费者模型中使用观察者进行通知与解耦? 1 假设场景设计 Producer&#xff08;生产者&#xff09;&#xff1a;生…

MVC和MVVM架构的区别

MVC和MVVM都是前端开发中常用的设计模式&#xff0c;都是为了解决前端开发中的复杂性而设计的&#xff0c;而MVVM模式则是一种基于MVC模式的新模式。 MVC(Model-View-Controller)的三个核心部分&#xff1a;模型、视图、控制器相较于MVVM(Model-View-ViewModel)的三个核心部分…

兰亭妙微 | 图标设计公司 | UI设计案例复盘

在「33」「312」新高考模式下&#xff0c;选科决策成为高中生和家长的「头等大事」。兰亭妙微公司受委托优化高考选科决策平台个人诊断报告界面&#xff0c;核心挑战是&#xff1a;如何将复杂的测评数据&#xff08;如学习能力倾向、学科报考机会、职业兴趣等&#xff09;转化为…

有铜半孔的设计规范与材料创新

设计关键参数 孔径与间距限制 最小孔径需≥0.6mm&#xff0c;孔边距≥0.5mm&#xff0c;避免铜层脱落&#xff1b;拼版时半孔区域需预留2mm间距防止撕裂。 阻焊桥设计 必须保留阻焊桥&#xff08;宽度≥0.1mm&#xff09;&#xff0c;防止焊锡流入孔内造成短路。 猎板的材料…

Engineering a direct k-way Hypergraph Partitioning Algorithm【2017 ALENEX】

文章目录 一、作者二、摘要三、相关工作四、算法概述五、实验结果六、主要贡献 一、作者 Yaroslav Akhremtsev, Tobias Heuer, Peter Sanders, Sebastian Schlag 二、摘要 我们开发了一种快速且高质量的多层算法&#xff0c;能够直接将超图划分为 k 个平衡的块 —— 无需借助递…

视频问答功能播放器(视频问答)视频弹题功能实例

视频问答播放器是一种互动教学工具&#xff0c;在视频播放过程中弹出题目卡&#xff0c;学员答题后才能继续观看&#xff0c;提升学习参与度。视频问答功能播放器(视频问答)视频弹题功能实例&#xff1a; 视频播放器的视频问答功能&#xff08;也叫问答播放器、视频弹题、视频问…

2025年AI代理演进全景:从技术成熟度曲线到产业重构

2025年AI代理演进全景&#xff1a;从技术成熟度曲线到产业重构 一、技术成熟度曲线定位&#xff1a;AI代理的“期望膨胀期” 根据Gartner技术成熟度曲线&#xff08;Hype Cycle™&#xff09;&#xff0c;AI代理&#xff08;Agentic AI&#xff09;当前正处于期望膨胀期向泡沫…

基于python的机器学习(八)—— 评估算法(一)

目录 一、机器学习评估的基本概念 1.1 评估的定义与目标 1.2 常见评估指标 1.3 训练集、验证集与测试集的划分 二、分离数据集 2.1 分离训练数据集和评估数据集 2.2 k折交叉验证分离 2.3 弃一交叉验证分离 2.4 重复随机评估和训练数据集分离 三、交叉验证技术 3.…

Win11 系统登入时绑定微软邮箱导致用户名欠缺

Win11 系统登入时绑定微软邮箱导致用户名欠缺 解决思路 -> 解绑当前微软邮箱和用户名 -> 断网离线建立本地账户 -> 设置本地账户为Admin权限 -> 注销当前账户&#xff0c;登入新建的用户 -> 联网绑定微软邮箱 -> 删除旧的用户命令步骤 管理员权限打开…

Mac系统-最方便的一键环境部署软件ServBay(支持php,java,python,node,go,mysql等)没有之一,已亲自使用!

自从换成Mac电脑以后&#xff0c;做开发有时候要部署各种环境&#xff0c;如php&#xff0c;mysql&#xff0c;nginx&#xff0c;pgsql&#xff0c;java&#xff0c;node&#xff0c;python&#xff0c;go时&#xff0c;尝试过原生环境部署&#xff0c;各种第三方软件部署&…

Flink中Kafka连接器的基本应用

文章目录 前言Kafka连接器基础案例演示前置说明和环境准备步骤Kafka连接器基本配置关联数据源映射转换案例效果演示基于Kafka连接器同步数据到MySQL案例说明前置准备Kafka连接器消费位点调整映射转换与数据投递MysqlSlink持久化收集器数据最终效果演示小结参考前言 本文将基于…

Leetcode 刷题记录 11 —— 二叉树第二弹

本系列为笔者的 Leetcode 刷题记录&#xff0c;顺序为 Hot 100 题官方顺序&#xff0c;根据标签命名&#xff0c;记录笔者总结的做题思路&#xff0c;附部分代码解释和疑问解答&#xff0c;01~07为C语言&#xff0c;08及以后为Java语言。 01 二叉树的层序遍历 /*** Definition…

【R语言科研绘图】

R语言在绘制SCI期刊图像时具有显著优势&#xff0c;以下从功能、灵活性和学术适配性三个方面分析其适用性&#xff1a; 数据可视化库丰富 R语言拥有ggplot2、lattice、ggpubr等专业绘图包&#xff0c;支持生成符合SCI期刊要求的高分辨率图像&#xff08;如TIFF/PDF格式&#…

【Node.js】Web开发框架

个人主页&#xff1a;Guiat 归属专栏&#xff1a;node.js 文章目录 1. Node.js Web框架概述1.1 Web框架的作用1.2 Node.js主要Web框架生态1.3 框架选择考虑因素 2. Express.js2.1 Express.js概述2.2 基本用法2.2.1 安装Express2.2.2 创建基本服务器 2.3 路由2.4 中间件2.5 请求…

PDF 转 JPG 图片小工具:CodeBuddy 助力解决转换痛点

本文所使用的 CodeBuddy 免费下载链接&#xff1a;腾讯云代码助手 CodeBuddy - AI 时代的智能编程伙伴 前言 在数字化办公与内容创作的浪潮中&#xff0c;将 PDF 文件转换为 JPG 图片格式的需求日益频繁。无论是学术文献中的图表提取&#xff0c;还是宣传资料的视觉化呈现&am…

Linux 文件系统层次结构

Linux 的文件系统遵循 Filesystem Hierarchy Standard (FHS) 标准&#xff0c;其目录结构是层次化的&#xff0c;每个目录都有明确的用途。以下是 Linux 中部分目录的作用解析&#xff1a; 1. 根目录 / 作用&#xff1a;根目录是整个文件系统的顶层目录&#xff0c;所有其他目…