day54 python对抗生成网络

目录

一、GAN对抗生成网络思想

二、实践过程

1. 数据准备

2. 构建生成器和判别器

3. 训练过程

4. 生成结果与可视化

三、学习总结


一、GAN对抗生成网络思想

GAN的核心思想非常有趣且富有对抗性。它由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的任务是从随机噪声中生成尽可能接近真实数据的样本,而判别器的任务则是区分生成器生成的假样本和真实样本。这两个网络在训练过程中相互对抗,生成器不断改进生成的样本以欺骗判别器,判别器则不断提升自己的辨别能力。最终,当生成器生成的样本足够逼真,以至于判别器难以区分真假时,GAN达到了一种平衡状态。

从数学角度来看,GAN的损失函数由两部分组成:生成器的损失和判别器的损失。判别器的损失是一个二分类问题的损失,通常使用二元交叉熵损失(BCELoss)。生成器的损失则依赖于判别器的反馈,目标是让判别器将生成的样本误判为真实样本。这种对抗机制使得GAN能够生成高质量的样本,尤其是在图像生成领域。

二、实践过程

为了更好地理解GAN的工作原理,我使用了Python和PyTorch框架实现了一个简单的GAN模型。以下是我的实践过程和代码实现。

1. 数据准备

我选择了经典的鸢尾花(Iris)数据集中的“Setosa”类别作为实验对象。这个数据集包含4个特征,非常适合用来测试GAN模型。我首先对数据进行了归一化处理,将其缩放到[-1, 1]范围内,以提高模型的训练效果。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt# 加载数据
iris = load_iris()
X = iris.data
y = iris.target# 选择 Setosa 类别
X_class0 = X[y == 0]# 数据归一化
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X_class0)# 转换为 PyTorch Tensor
real_data_tensor = torch.from_numpy(X_scaled).float()
dataset = TensorDataset(real_data_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

2. 构建生成器和判别器

接下来,我定义了生成器和判别器的网络结构。生成器使用了简单的多层感知机(MLP)结构,输入是随机噪声,输出是与真实数据维度相同的样本。判别器同样使用MLP结构,输出是一个概率值,表示输入样本是真实样本的概率。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(10, 16),nn.ReLU(),nn.Linear(16, 32),nn.ReLU(),nn.Linear(32, 4),nn.Tanh())def forward(self, x):return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(4, 32),nn.LeakyReLU(0.2),nn.Linear(32, 16),nn.LeakyReLU(0.2),nn.Linear(16, 1),nn.Sigmoid())def forward(self, x):return self.model(x)

3. 训练过程

在训练过程中,我交替更新生成器和判别器的参数。每一步中,首先用真实数据和生成数据训练判别器,然后用生成数据训练生成器。通过这种方式,两个网络不断对抗,逐渐提升性能。

# 定义损失函数和优化器
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))# 训练循环
for epoch in range(10000):for i, (real_data,) in enumerate(dataloader):# 训练判别器d_optimizer.zero_grad()real_output = discriminator(real_data)d_loss_real = criterion(real_output, torch.ones_like(real_output))noise = torch.randn(real_data.size(0), 10)fake_data = generator(noise).detach()fake_output = discriminator(fake_data)d_loss_fake = criterion(fake_output, torch.zeros_like(fake_output))d_loss = d_loss_real + d_loss_faked_loss.backward()d_optimizer.step()# 训练生成器g_optimizer.zero_grad()fake_data = generator(noise)fake_output = discriminator(fake_data)g_loss = criterion(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optimizer.step()if (epoch + 1) % 1000 == 0:print(f"Epoch [{epoch+1}/10000], Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")

4. 生成结果与可视化

训练完成后,我使用生成器生成了一些新的样本,并将它们与真实样本进行了可视化对比。从结果可以看出,生成器生成的样本在分布上与真实样本较为接近,说明GAN模型在一定程度上成功地学习了数据的分布。

# 生成新样本
with torch.no_grad():noise = torch.randn(50, 10)generated_data_scaled = generator(noise)# 逆向转换回原始尺度
generated_data = scaler.inverse_transform(generated_data_scaled.numpy())
real_data_original_scale = scaler.inverse_transform(X_scaled)# 可视化对比
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('真实数据 vs. GAN生成数据 的特征分布对比', fontsize=16)
feature_names = iris.feature_namesfor i, ax in enumerate(axes.flatten()):ax.hist(real_data_original_scale[:, i], bins=10, density=True, alpha=0.6, label='Real Data')ax.hist(generated_data[:, i], bins=10, density=True, alpha=0.6, label='Generated Data')ax.set_title(feature_names[i])ax.legend()plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

三、学习总结

通过这次实践,我对GAN的工作原理有了更深入的理解。GAN的核心在于生成器和判别器的对抗机制,这种机制使得模型能够生成高质量的样本。在实际应用中,GAN不仅可以用于图像生成,还可以用于数据增强、风格迁移等任务。

然而,GAN的训练过程也存在一些挑战。例如,生成器和判别器的平衡很难把握,如果其中一个网络过于强大,可能会导致训练失败。此外,GAN的训练过程通常需要大量的计算资源和时间。

在未来的学习中,我计划探索更多GAN的变体,如WGAN、DCGAN等,以更好地理解和应用生成对抗网络。同时,我也希望能够将GAN应用于更复杂的任务中,例如图像生成和视频生成,进一步提升我的深度学习技能。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/84725.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

龙虎榜——20250613

上证指数放量下跌收阴线,个股下跌超4000只,受外围消息影响情绪总体较差。 深证指数放量下跌,收阴线,6月总体外围风险较高,转下跌走势的概率较大,注意风险。 2025年6月13日龙虎榜行业方向分析 1. 石油石化&…

Linux常用命令加强版替代品

Linux常用命令加强版替代品 还在日复一日地使用 ls、grep、cd 这些“上古”命令吗?是时候给你的终端来一次大升级了!本文将为你介绍一系列强大、高效且设计现代的Linux命令行工具,它们将彻底改变你的工作流,让你爱上在终端里操作…

Hadoop 003 — JAVA操作MapReduce入门案例

MapReduce入门案例-分词统计 文章目录 MapReduce入门案例-分词统计1.xml依赖2.编写MapReduce处理逻辑3.上传统计文件到HDFS3.配置MapReduce作业并测试4.执行结果 1.xml依赖 <dependency><groupId>org.apache.hadoop</groupId><artifactId>hadoop-commo…

Python打卡第53天

浙大疏锦行 作业&#xff1a; 对于心脏病数据集&#xff0c;对于病人这个不平衡的样本用GAN来学习并生成病人样本&#xff0c;观察不用GAN和用GAN的F1分数差异。 import pandas as pd import numpy as np import torch import torch.nn as nn import torch.optim as optim from…

力扣-279.完全平方数

题目描述 给你一个整数 n &#xff0c;返回 和为 n 的完全平方数的最少数量 。 完全平方数 是一个整数&#xff0c;其值等于另一个整数的平方&#xff1b;换句话说&#xff0c;其值等于一个整数自乘的积。例如&#xff0c;1、4、9 和 16 都是完全平方数&#xff0c;而 3 和 1…

前端构建工具Webapck、Vite——>前沿字节开源Rspack详解——2023D2大会

Rspack 以下是针对主流构建工具&#xff08;Webpack、Vite、Rollup、esbuild&#xff09;的核心不足分析&#xff0c;以及 Rspack 如何基于这些痛点进行针对性改进 的深度解析&#xff1a; 一、主流构建工具的不足 1. Webpack&#xff1a;性能与生态的失衡 核心问题 冷启动慢…

输入法,开头输入这U I V 三个字母会不显示 任何中文

1. 汉语拼音规则的限制 汉语拼音中不存在以“V”“U”“I”为声母的情况&#xff1a; 汉语拼音的声母是辅音&#xff0c;而“V”“U”“I”在汉语拼音中都是元音&#xff08;或韵母的一部分&#xff09;。汉语拼音的声母系统中没有“V”“U”“I”作为声母的音节。例如&#xf…

Linux文件权限详解:从入门到精通

前言 权限是什么&#xff1f; 本质&#xff1a;无非就是能做和不能做什么。 为什么要有权限呢&#xff1f; 目的&#xff1a;为了控制用户行为&#xff0c;防止发生错误。 1.权限的理解 在学习下面知识之前要先知道的一点是&#xff1a;linux下一切皆文件&#xff0c;对li…

在多云环境透析连接ngx_stream_proxy_protocol_vendor_module

1、模块定位与价值 多云接入&#xff1a;在同一 Nginx 实例前端接入来自多云平台的私有链路时&#xff0c;能区分 AWS、GCP、Azure 特有的连接 ID。安全审计&#xff1a;自动记录云平台侧的 Endpoint/VPC ID&#xff0c;有助于联调和安全事件追踪。路由分流&#xff1a;基于不…

力扣:基本计算器

基本计算器: 224. 基本计算器 - 力扣&#xff08;LeetCode&#xff09; 本体思路为&#xff0c;将中缀表达式转为后缀表达式&#xff0c;通过后缀表达式进行运算。 中缀表达式: 我们日常生活中熟知的表达式如12-30 就是一个中缀表达式。 后缀表达式: 150. 逆波兰表达式求值 - …

《AI日报 · 0613|ChatGPT支持导出、Manus免费开放、GCP全球宕机》

AI 资讯 1️⃣ OpenAI ChatGPT Canvas新增多格式导出功能 OpenAI终于为ChatGPT Canvas推出了用户期待已久的导出功能。现在,用户可以将创作内容导出为多种格式:文档类支持PDF、docx和markdown格式,代码文件则可直接保存为对应扩展名的源文件(如.py、.js、.sql等)。这一功…

C++中的零拷贝技术

一、C中零拷贝技术的核心概念 零拷贝&#xff08;Zero-copy&#xff09;是一种重要的优化技术&#xff0c;旨在减少数据在内存中的不必要复制&#xff0c;从而提高程序性能、降低内存使用并减少CPU消耗。在C中&#xff0c;零拷贝技术通过多种方式实现&#xff0c;包括引用语义…

RT_Thread内核源码分析(五)——内存管理@小堆内存管理算法

目录 1、内存堆控制 1.1 内存堆控制器 1.2 内存块节点 1.3 内存堆管理 2、内存堆初始化 2.1 初始化接口 2.2 初始化示例 2.3 源码分析 3、内存堆操作 3.1 内存块申请 3.1.1 相关接口 3.1.2 原理分析 3.1.3 示例分析 3.1.4 代码分析 3.2 内存块伸缩 3.2.1 相关…

MyBatis-Plus 混合使用 XML 和注解

mybatisplus代码生成器&#xff1a; 版本匹配是个比较麻烦的问题&#xff0c;这是我的配置&#xff1a; <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.2</version>…

基于ssm的教学质量评估系统

博主介绍&#xff1a;java高级开发&#xff0c;从事互联网行业六年&#xff0c;熟悉各种主流语言&#xff0c;精通java、python、php、爬虫、web开发&#xff0c;已经做了六年的毕业设计程序开发&#xff0c;开发过上千套毕业设计程序&#xff0c;没有什么华丽的语言&#xff0…

【STM32】G030单片机开启超过8个ADC通道的方法

如图所示通道数量已经超过8个&#xff0c;按照之前博客的办法已经行不通了 CubeMX配置STM32F103C8T6多路ADC配合DMA采集_stm32f103c8t6的adc采样率-CSDN博客 这里笔者开了10个channel&#xff0c;注意切换为不完全配置&#xff0c;否则的话最多只有8个rank 开DMA&#xff0c;…

不同网络I/O模型的原理

目录 1、I/O的介绍 1.1、I/O 操作分类 1.2、I/O操作流程阶段 1.3、I/O分类 2、同步I/O 2.1、阻塞I/O 2.2、非阻塞I/O 2.3、I/O复用 2.4、信号驱动式I/O 3、异步I/O 前言 在网络I/O之中&#xff0c;I/O操作往往会涉及到两个系统对象&#xff0c;一个是用户空间调用I/O…

在正则表达式中语法 (?P<名字>内容)

&#x1f3af; 重点解释&#xff1a;?P<xxx> 是什么语法&#xff1f; 这一整段&#xff1a; (?P<xxx>...)是 Python 正则表达式中 “命名捕获组” 的语法。 咱们现在一个字一个字来解释&#xff1a; ✅ (?...) 是干啥的&#xff1f; 这是一个捕获组&#xff…

中兴B860AV1.1_MSO9280_降级后开ADB-免刷机破解教程(非刷机)

中兴B860AV1.1江苏移动-自动降级包 关于中兴b860av1.1顽固盒子降级教程终极版 将附件解压好以后&#xff0c;准备一个8G以下的U盘重新格式化为FAT32格式后&#xff0c;并插入电脑 将以下文件及文件夹一同复制到优盘主目录下&#xff08;见下图&#xff09; 全选并复制到U盘主目…

2025-06-13【视频处理】基于视频内容转场进行分割

问题&#xff1a;从网上下载的视频文件&#xff0c;是由很多个各种不同的场景视频片段合并而成。现在要求精确的把各个视频片段从大视频里分割出来。 效果如图&#xff1a;已分割出来的小片段 思考过程 难点在于检测场景变化。为什么呢&#xff1f;因为不同的视频情况各异&am…