第40周——GAN入门

目录

目录

目录

前言

一、定义超参数

二、下载数据

三、配置数据

四、定义鉴别器

五、训练模型并保存

总结


前言

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

一、定义超参数

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch## 创建文件夹
os.makedirs("./images/", exist_ok=True)         # 记录训练过程的图片效果
os.makedirs("./save/", exist_ok=True)           # 训练完成时模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True)  # 下载数据集存放的位置## 超参数配置
n_epochs  = 50
batch_size= 64
lr        = 0.0002
b1        = 0.5
b2        = 0.999
n_cpu     = 2
latent_dim= 100
img_size  = 28
channels  = 1
sample_interval=500# 图像的尺寸:(1, 28, 28),  和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)# 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

二、下载数据

# mnist数据集下载
mnist = datasets.MNIST(root='./datasets/', train=True, download=True, transform=transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]), 
)


三、配置数据

# 配置数据到加载器
dataloader = DataLoader(mnist,batch_size=batch_size,shuffle=True,
)

四、定义鉴别器

# 将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
# 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_area, 512),         # 输入特征数为784,输出为512nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射nn.Linear(512, 256),              # 输入特征数为512,输出为256nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射nn.Linear(256, 1),                # 输入特征数为256,输出为1nn.Sigmoid(),                     # sigmoid是一个激活函数,二分类问题中可将实数映射到[0, 1],作为概率值, 多分类用softmax函数)def forward(self, img):img_flat = img.view(img.size(0), -1) # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)validity = self.model(img_flat)      # 通过鉴别器网络return validity       

五、训练模型并保存

## 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()## 首先需要定义loss的度量方式  (二分类的交叉熵)
criterion = torch.nn.BCELoss()## 其次定义 优化函数,优化函数的学习率为0.0003
## betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))## 如果有显卡,都在cuda模式中运行
if torch.cuda.is_available():generator     = generator.cuda()discriminator = discriminator.cuda()criterion     = criterion.cuda()## 进行多个epoch的训练
for epoch in range(n_epochs):                   # epoch:50for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)## =============================训练判别器==================## view(): 相当于numpy中的reshape,重新定义矩阵的形状, 相当于reshape(128,784)  原来是(128, 1, 28, 28)imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)real_img = Variable(imgs).cuda()      # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定义真实的图片label为1fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定义假的图片的label为0## ---------------------##  Train Discriminator## 分为两部分:1、真的图像判别为真;2、假的图像判别为假## ---------------------## 计算真实图片的损失real_out = discriminator(real_img)            # 将真实图片放入判别器中loss_real_D = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好## 计算假的图片的损失## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 随机生成一些噪声, 大小为(128, 100)fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。 fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的lossfake_scores = fake_out                                              ## 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好## 损失函数和优化loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0loss_D.backward()                   # 将误差反向传播optimizer_D.step()                  # 更新参数## -----------------##  Train Generator## 原理:目的是希望生成的假的图片被判别器判断为真的图片,## 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,## 反向传播更新的参数是生成网络里面的参数,## 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的## -----------------z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到随机噪声fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片output = discriminator(fake_img)                                    ## 经过判别器得到的结果## 损失函数和优化loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的lossoptimizer_G.zero_grad()                                             ## 梯度归0loss_G.backward()                                                   ## 进行反向传播optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数## 打印训练过程中的日志## item():取出单元素张量的元素值并返回该值,保持原元素类型不变if (i + 1) % 300 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"% (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))## 保存训练过程中的图像batches_done = epoch * len(dataloader) + iif batches_done % sample_interval == 0:save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)## 保存模型
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')


总结:

本项目中,我们实现了一个基于 MNIST 数据集的生成对抗网络(GAN),主要流程从参数配置、数据准备,到模型构建与训练,最后再到结果保存,形成了一个完整的生成式模型训练管线。

首先,在超参数设置上,我们采用了经典 GAN 论文中推荐的组合:学习率设为 0.0002,Adam 优化器的 \beta_1 和 \beta_2 分别为 0.5 和 0.999,训练 50 个周期,批量大小为 64。这些参数能在保证稳定训练的同时,加快收敛速度。

数据部分选用了MNIST 手写数字集,先将像素归一化到 [-1, 1],再通过 DataLoader 按批次读取并打乱顺序。这一处理不仅保证了输入分布的稳定性,也提升了训练效率。

在模型结构方面,判别器(D)是一个多层全连接网络,将 28×28 图像展平为 784 维向量后输入,激活函数使用 LeakyReLU,最后通过 Sigmoid 得到真假概率;生成器(G)则以 100 维随机噪声为输入,经过多层全连接与 ReLU/Tanh 激活,输出与真实图像同尺寸的 28×28 结果。这样的结构简单直观,适合入门实验。

训练时,判别器与生成器交替优化:

  • 判别器的目标是最大化对真实图像的判真概率、对生成图像的判假概率;

  • 生成器的目标则是让判别器将其生成的图像判为真。

    损失函数统一采用二分类交叉熵(BCELoss),并为 D、G 分别设置优化器以避免梯度更新冲突。训练过程中会定期输出损失值并保存生成样本,方便对生成效果进行直观评估。

GAN 的核心思想,是让生成器与判别器在对抗博弈中共同提升能力:G 学会捕捉数据分布特征,D 学会分辨真实与伪造,两者在动态平衡中逼近真实分布。这种机制使得 GAN 特别适合用于图像生成、数据增强和风格迁移等任务。

从实验体验来看,这套代码的优点在于结构清晰、可视化直观、参数稳定,非常适合作为学习 GAN 的起点。同时,经过适当修改,还能扩展到更复杂的生成任务,为后续的研究打下基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/93250.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/93250.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx性能优化与安全配置:打造高性能Web服务器

系列文章索引: 第一篇:《Nginx入门与安装详解:从零开始搭建高性能Web服务器》第二篇:《Nginx基础配置详解:nginx.conf核心配置与虚拟主机实战》第三篇:《Nginx代理配置详解:正向代理与反向代理…

二分算法(模板)

例题1: 704. 二分查找 - 力扣(LeetCode) 算法原理:(二分) 通过遍历也可以通过,但是二分更优且数据量越大越能体现。 二分思路: 1.mid1 (left right)/2 与 mid2 right (right …

VUE3 学习笔记2 computed、watch、生命周期、hooks、其他组合式API

computed 计算属性在vue3中,虽然也能写vue2的computed,但还是更推荐使用vue3语法的computed。在Vue3中,计算属性是组合式API,要想使用computed,需要先对computed进行引入:import { computed } from vuecomp…

【java面试day13】mysql-定位慢查询

文章目录问题💬 Question 1相关知识问题 💬 Question 1 Q:这条sql语句执行很慢,你如何分析呢? A:当一条 SQL 执行较慢时,可以先使用 EXPLAIN 查看执行计划,通过 key 和 key_len 判…

3分钟解锁网页“硬盘“能力:离线运行VSCode的新一代Web存储技术

Hi,我是前端人类学(之前叫布兰妮甜)! “这不是浏览器,这是装了个硬盘。” —— 用户对现代Web应用能力的惊叹 随着Origin Private File System和IndexedDB Stream等新技术的出现,Web应用现在可以在用户的设…

LT6911GXD,HD-DVI2.1/DP1.4a/Type-C 转 Dual-port MIPI/LVDS with Audio 带音频

简介LT6911GXD是一款高性能HD-DVI2.1/DP1.4a/Type-c转Dual-port MIPI/LVDS芯片,兼容 HDMI2.1、HDMI2.0b、HDMI1.4、DVI1.0、DisplayPort 1.4a、eDP1.4b 等多种视频接口标准。支持4K(38402160)60Hz的DSC直通。应用场景AR/VR设备LT6911GXD 支持高达 4K(384…

【100页PPT】数字化转型某著名企业集团信息化顶层规划方案(附下载方式)

篇幅所限,本文只提供部分资料内容,完整资料请看下面链接 https://download.csdn.net/download/2501_92808811/91662628 资料解读:数字化转型某著名企业集团信息化顶层规划方案 详细资料请看本解读文章的最后内容 作为企业数字化转型领域的…

高精度标准钢卷尺优质厂家、选购建议

高精度标准钢卷尺的优质厂家通常具备精湛工艺与权威精度认证等特征,能为产品质量提供保障。其选购需兼顾精度标识、使用场景、结构细节等多方面,具体介绍如下:一、高精度标准钢卷尺优质厂家**1、河南普天同创:**PTTC-C5标准钢卷尺…

38 C++ STL模板库7-迭代器

C STL模板库7-迭代器 文章目录C STL模板库7-迭代器一、迭代器的核心作用二、迭代器的五大分类与操作三、关键用法与代码示例1. 迭代器的原理2. 迭代器用法与示例3. 迭代工具用法示例4. 使用技巧迭代器是C中连接容器与算法的通用接口,提供了一种访问容器元素的统一方…

【0基础3ds Max】学习计划

3ds Max 作为一款功能强大的专业 3D 计算机图形软件,在影视动画、游戏开发、建筑可视化、产品设计和工业设计等众多领域有着广泛的应用。 目录前言一、第一阶段:基础认知(第 1 - 2 周)​二、第二阶段:建模技术学习&…

用 Enigma Virtual Box 将 Qt 程序打包成单 exe

上一篇介绍了用windeployqt生成可运行的多文件程序,但一堆文件分发起来不够方便。有没有办法将所有文件合并成一个 exe? 答案是肯定的 用Enigma Virtual Box工具就能实现。本文就来讲解如何用它将 Qt 多文件程序打包为单一 exe,让分发更轻松。 其中的 一定要选 第二个 一…

【LeetCode 热题 100】45. 跳跃游戏 II

Problem: 45. 跳跃游戏 II 给定一个长度为 n 的 0 索引整数数组 nums。初始位置为 nums[0]。 每个元素 nums[i] 表示从索引 i 向后跳转的最大长度。换句话说&#xff0c;如果你在索引 i 处&#xff0c;你可以跳转到任意 (i j) 处&#xff1a; 0 < j < nums[i] 且 i j &…

池式管理之线程池

1.初识线程池问&#xff1a;线程池是什么&#xff1f;答&#xff1a;维持管理一定数量的线程的池式结构。&#xff08;维持&#xff1a;线程复用 。 管理&#xff1a;没有收到任务的线程处于阻塞休眠状态不参与cpu调度 。一定数量&#xff1a;数量太多的线程会给操作系统带来线…

婴儿 3D 安睡系统专利拆解:搭扣与智能系带的锁定机制及松紧调节原理

凌晨2点&#xff0c;你盯着婴儿床里的小肉团直叹气。刚用襁褓裹成小粽子才哄睡的宝宝&#xff0c;才半小时就蹬开了裹布&#xff0c;小胳膊支棱得像只小考拉&#xff1b;你手忙脚乱想重新裹紧&#xff0c;结果越裹越松&#xff0c;裹布滑到脖子边&#xff0c;宝宝突然一个翻身&…

pandas中df.to _dict(orient=‘records‘)方法的作用和场景说明

df.to _dict(orientrecords) 是 Pandas DataFrame 的一个方法&#xff0c;用于将数据转换为字典列表格式。以下是详细解释及实例说明&#xff1a; 一、核心含义作用 将 DataFrame 的每一行转换为一个字典&#xff0c;所有字典组成一个列表。 每个字典的键&#xff08;key&#…

阿里云Anolis OS 8.6的公有云仓库源配置步骤

文章目录一、备份现有仓库配置&#xff08;防止误操作&#xff09;二、配置阿里云镜像源2.1 修改 BaseOS 仓库2.2 修改 AppStream 仓库三、清理并重建缓存四、验证配置4.1 ​检查仓库状态​&#xff1a;五、常见问题解决5.1 ​HTTP 404 错误5.2 ​网络连接问题附&#xff1a;其…

回归预测 | Matlab实现CNN-BiLSTM-self-Attention多变量回归预测

回归预测 | Matlab实现CNN-BiLSTM-self-Attention多变量回归预测 目录回归预测 | Matlab实现CNN-BiLSTM-self-Attention多变量回归预测预测效果基本介绍程序设计参考资料预测效果 基本介绍 1.Matlab实现CNN-BiLSTM融合自注意力机制多变量回归预测&#xff0c;CNN-BiLSTM-self-…

103、【OS】【Nuttx】【周边】文档构建渲染:Sphinx 配置文件

【声明】本博客所有内容均为个人业余时间创作&#xff0c;所述技术案例均来自公开开源项目&#xff08;如Github&#xff0c;Apache基金会&#xff09;&#xff0c;不涉及任何企业机密或未公开技术&#xff0c;如有侵权请联系删除 背景 接之前 blog 【OS】【Nuttx】【周边】文…

转换一个python项目到moonbit,碰到报错输出:编译器对workflow.mbt文件中的类方法要求不一致的类型注解,导致无法正常编译

先上结论&#xff1a;现在是moon test的时候有很多报错&#xff0c;消不掉。问题在Trae中用GLM-4.5模型&#xff0c;转换一个python项目到moonbit&#xff0c;碰到报错输出&#xff1a;报错输出经过多次尝试修复&#xff0c;我发现这是一个MoonBit编译器的bug。编译器对workflo…

【C#补全计划】事件

一、事件的概念1. 事件是基于委托的存在&#xff0c;是委托的安全包裹&#xff0c;让委托的使用更具有安全性2. 事件是一种特殊的变量类型二、事件的使用1. 语法&#xff1a;event 委托类型 事件名;2. 使用&#xff1a;&#xff08;1&#xff09;事件是作为成员变量存在与类中&…