深度学习G3周:CGAN入门(生成手势图像)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

基础任务:

1.条件生成对抗网络(CGAN)的基本原理

2.CGAN是如何实现条件控制的

3.学习本文CGAN代码,并跑通代码

进阶任务:

生成指定手势的图像

一、理论知识

条件生成对抗网络(CGAN):在生成对抗网络(GAN)的基础上进行了一些改进。对于原始GAN的生成器而言,其生成的图像数据是随机不可预测的,因此,我们无法控制网络的输出,在实际操作中的可控性不强。

针对上述原始GAN无法生成具有特定属性的图像数据的问题,Mehdi Mirza等人在2014年提出了条件生成对抗网络,通过给原始生成对抗网络中的生成器G和判别器D增加额外的条件,如我们需要生成器G生成一张没有阴影的图像,此时判别器D就需要判断生成器所生成的图像是否是一张没有阴影的图像。

条件生成对抗网络的本质:将额外添加的信息融入到生成器和判别器中,其中添加的信息可以是图像的类别、人脸表情和其他辅助信息等,旨在把无监督学习的GAN转化为有监督学习的CGAN,便于网络能在我们的掌控下更好地进行训练。

CGAN网络结构图:

由图可知,条件信息y作为额外的输入被引入对抗网络中,与生成器中的噪声z合并作为隐含层表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的诸多方面研究中被证明是非常有效的,也为后续的相关工作提供了积极的指导作用。

二、准备工作

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt
import datetimetorch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

1.导入数据

train_transform = transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])train_dataset = datasets.ImageFolder(root="D:/study/data/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)def show_images(images):fig, ax = plt.subplots(figsize=(20, 20))ax.set_xticks([]); ax.set_yticks([])ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))def show_batch(dl):for images, _ in dl:show_images(images)breakshow_batch(train_loader)

遇到问题:

多线程缘故---去掉num_workers参数

image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))
latent_dim = 100n_classes = 3
embedding_dim = 100

三、构建模型

def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find('BatchNorm') != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)

1.构建生成器

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_conditioned_generator = nn.Sequential(nn.Embedding(n_classes, embedding_dim), nn.Linear(embedding_dim, 16)            )self.latent = nn.Sequential(nn.Linear(latent_dim, 4*4*512),  nn.LeakyReLU(0.2, inplace=True)  )self.model = nn.Sequential( nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  nn.ReLU(True),            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.ReLU(True),     nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),nn.ReLU(True),       nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),nn.ReLU(True),       nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),nn.Tanh()  )def forward(self, inputs):noise_vector, label = inputs  label_output = self.label_conditioned_generator(label)     label_output = label_output.view(-1, 1, 4, 4)        latent_output = self.latent(noise_vector)     latent_output = latent_output.view(-1, 512, 4, 4) concat = torch.cat((latent_output, label_output), dim=1)image = self.model(concat)return image
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

输出:

from torchinfo import summary
summary(generator)
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)

输出:

 

2.构建鉴别器

import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.label_condition_disc = nn.Sequential(nn.Embedding(n_classes, embedding_dim),     nn.Linear(embedding_dim, 3*128*128)         )self.model = nn.Sequential(nn.Conv2d(6, 64, 4, 2, 1, bias=False),      nn.LeakyReLU(0.2, inplace=True),             nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Flatten(),                               nn.Dropout(0.4),                            nn.Linear(4608, 1),                         nn.Sigmoid()                                )def forward(self, inputs):img, label = inputslabel_output = self.label_condition_disc(label)label_output = label_output.view(-1, 3, 128, 128)concat = torch.cat((img, label_output), dim=1)output = self.model(concat)return output
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)
summary(discriminator)
a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
c = discriminator((a,b))
c.size()

输出:

 

四、训练模型

1.定义损失函数

adversarial_loss = nn.BCELoss() def generator_loss(fake_output, label):gen_loss = adversarial_loss(fake_output, label)return gen_lossdef discriminator_loss(output, label):disc_loss = adversarial_loss(output, label)return disc_loss

2.定义优化器

learning_rate = 0.0002G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3.训练模型

代码逻辑结构图:

1.首先设置了训练的总轮数和用于存储每轮训练中判别器和生成器损失的列表

2.然后进行GAN模型的训练。在每轮训练中,它首先从数据加载器中加载真实图像和标签,然后计算判别器对真实图像的损失,接着从噪声向量中生成假图像,计算判别器对假图像的损失,计算判别器总体损失并反向传播更新判别器参数,然后计算生成器的损失并反向传播更新生成器的参数

3.最后,它打印当前轮次的判别器和生成器的平均损失,并将当前轮次的判别器和生成器的平均损失保存到列表中

4.在每10轮训练后,它会将生成的假图像保存为图片文件,并将当前轮次的生成器和判别器的权重保存到文件

num_epochs = 100D_loss_plot, G_loss_plot = [], []for epoch in range(1, num_epochs + 1):D_loss_list, G_loss_list = [], []for index, (real_images, labels) in enumerate(train_loader):D_optimizer.zero_grad()real_images = real_images.to(device)labels      = labels.to(device)labels = labels.unsqueeze(1).long()real_target = Variable(torch.ones(real_images.size(0), 1).to(device))fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)noise_vector = noise_vector.to(device)generated_image = generator((noise_vector, labels))output = discriminator((generated_image.detach(), labels))D_fake_loss = discriminator_loss(output, fake_target)D_total_loss = (D_real_loss + D_fake_loss) / 2D_loss_list.append(D_total_loss)D_total_loss.backward()D_optimizer.step()G_optimizer.zero_grad()G_loss = generator_loss(discriminator((generated_image, labels)), real_target)G_loss_list.append(G_loss)G_loss.backward()G_optimizer.step()print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % ((epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), torch.mean(torch.FloatTensor(G_loss_list))))D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))if epoch%10 == 0:save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

输出:

 

五、模型分析

1.加载模型

generator.load_state_dict(torch.load('./training_weights/generator_epoch_100.pth'), strict=False)
generator.eval()               
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from numpy import linspace
from matplotlib import pyplot
from matplotlib import gridspecdef generate_latent_points(latent_dim, n_samples, n_classes=3):x_input = randn(latent_dim * n_samples)z_input = x_input.reshape(n_samples, latent_dim)return z_inputdef interpolate_points(p1, p2, n_steps=10):ratios = linspace(0, 1, num=n_steps)vectors = list()for ratio in ratios:v = (1.0 - ratio) * p1 + ratio * p2vectors.append(v)return asarray(vectors)pts = generate_latent_points(100, 2)
interpolated = interpolate_points(pts[0], pts[1])
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)output = None
for label in range(3):labels = torch.ones(10) * labellabels = labels.to(device)labels = labels.unsqueeze(1).long()print(labels.size())predictions = generator((interpolated, labels))predictions = predictions.permute(0,2,3,1)pred = predictions.detach().cpu()if output is None:output = predelse:output = np.concatenate((output,pred))
output.shape
nrow = 3
ncol = 10fig = plt.figure(figsize=(15,4))
gs = gridspec.GridSpec(nrow, ncol) k = 0
for i in range(nrow):for j in range(ncol):pred = (output[k, :, :, :] + 1 ) * 127.5pred = np.array(pred)  ax= plt.subplot(gs[i,j])ax.imshow(pred.astype(np.uint8))ax.set_xticklabels([])ax.set_yticklabels([])ax.axis('off')k += 1   plt.show()

五、总结

学习了条件生成对抗网络的基本原理和代码。了解CGAN是怎么实现条件控制。上次遇到的问题这次又忘记了,还是得多看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89287.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89287.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

流式数据处理实战:用状态机 + scan 优雅过滤 AI 响应中的 `<think>` 标签

流式数据处理实战&#xff1a;用状态机 scan 优雅过滤 AI 响应中的 <think> 标签 1. 引言&#xff1a;流式数据处理的挑战 在现代 AI 应用开发中&#xff0c;流式 API&#xff08;如 OpenAI、Claude 等&#xff09;能实时返回分块数据&#xff0c;提升用户体验。但流式…

【实时Linux实战系列】硬件中断与实时性

在实时系统中&#xff0c;硬件中断是系统响应外部事件的关键机制之一。硬件中断允许系统在执行任务时被外部事件打断&#xff0c;从而快速响应这些事件。然而&#xff0c;中断处理不当可能会导致系统延迟增加&#xff0c;影响系统的实时性。因此&#xff0c;优化中断处理对于提…

基于DTLC-AEC与DTLN的轻量级实时语音降噪系统设计与实现

基于DTLC-AEC与DTLN的轻量级实时语音降噪系统设计与实现 1. 引言 在当今的实时通信应用中,语音质量是影响用户体验的关键因素之一。环境噪声和回声会严重降低语音清晰度,特别是在移动设备和嵌入式系统上。本文将详细介绍如何将两种先进的开源模型——DTLC-AEC(深度学习回声…

基于Hadoop与LightFM的美妆推荐系统设计与实现

文章目录有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主项目介绍总结每文一语有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主 项目介绍 本项目旨在基于大数据Hadoop平台和机器学习技术&#xff0c;构建一套面向美妆…

notepad++ 多行复制拼接

如何将中文一 一复制到英文后面按住 ALT ,鼠标左键拖动多行选中中文Ctrl C 复制 在英文的第一行结尾处 Ctrl v 粘贴

【前沿技术动态】【AI总结】Spring Boot 4.0 预览版深度解析:云原生时代的新里程碑

Spring Boot 4.0 预览版深度解析&#xff1a;云原生时代的新里程碑 最低 Java 17&#xff0c;原生支持虚拟线程&#xff0c;性能提升最高800%&#xff0c;Spring Boot 4.0 带来开发体验与运行时性能的全面飞跃 Spring Boot 4.0 的预览版在2025年5月底悄然上线&#xff0c;标志着…

OkHttp 框架封装一个 HTTP 客户端,用于调用外部服务接口

✅ 背景与需求 需要基于 OkHttp 框架封装一个 HTTP 客户端&#xff0c;用于调用外部服务接口&#xff08;如拼团回调&#xff09;&#xff0c;实现以下功能&#xff1a; 动态传入请求地址&#xff08;URL&#xff09;支持 JSON 请求体实现类放在 infrastructure 层的 gateway…

使用Collections.max比较Map<String, Integer>中的最大值

文章目录使用Collections.max比较Map<String, Integer>中的最大值基本方法1. 比较Map的值2. 比较Map的键自定义比较器1. 按值降序排列2. 复杂比较逻辑完整示例代码性能考虑替代方案1. 使用Stream API (Java 8)2. 手动遍历实际应用场景注意事项总结使用Collections.max比较…

鸿蒙状态栏操作

1.鸿蒙设备基础信息 1.1图解 1.1窗口内容规避区域 AvoidArea7 窗口内容规避区域。 窗口内容规避区域。如系统栏区域、刘海屏区域、手势区域、软键盘区域等与窗口内容重叠时&#xff0c;需要窗口内容避让的区域。在规避区无法响应用户点击事件。 除此之外还需注意规避区域的如…

Product Hunt 每日热榜 | 2025-07-17

1. Brain MAX by ClickUp 标语&#xff1a;一款AI应用统治一切&#xff1a;你的知识 语音转文字 介绍&#xff1a;Brain MAX 是 ClickUp 完全原生的桌面应用&#xff0c;旨在提升生产力&#xff0c;帮助你摆脱 AI 的杂乱无章。只需每月 9 美元&#xff0c;就可以使用所有的 …

如何使用VScode使用ssh连接远程服务器不需要输入密码直接登录

ssh-keygen 之后一直默认 回车 确认即可结果 (base) amaxamax:/data/std$ ssh-keygen Generating public/private rsa key pair. Enter file in which to save the key (/home/amax/.ssh/id_rsa): Enter passphrase (empty for no passphrase): Enter same passphrase again:…

vue实现el-table-column中自定义label

vue实现el-table-column中自定义label<el-table-columnlabel"操作"align"left"width"50"><template #header><div><el-buttonsize"mini"type"primary"icon"el-icon-plus"circle></el-…

Vue 常用的 ESLint 规则集

对Vue项目来说&#xff0c;Vue 官方通过 eslint-plugin-vue 提供了多个规则集&#xff08;Rule Sets&#xff09;&#xff0c;适用于不同严格度和 Vue 版本。以下是主要的规则集及其对应的 ESLint 插件和用途&#xff1a; 1. Vue 2.x 规则集 适用于 Vue 2 项目&#xff0c;规则…

AbMole小课堂 | Angiotensin II(血管紧张素Ⅱ)在心血管研究中的多元应用

Angiotensin II&#xff08;血管紧张素Ⅱ&#xff0c;AbMole&#xff0c;M6240&#xff09;是一种血管收缩剂&#xff0c;也是肾素-血管紧张素系统 (RAS) 的主要效应肽。Angiotensin II参与动物的血压调节、水电解质平衡等经典生理过程在科研中Angiotensin II被广泛用于动物心血…

【Unity】Mono相关理论知识学习

一种编译技术。优点&#xff1a;支持JIT编译&#xff1a;在运行时将IL编译成机器码。首次执行稍慢&#xff0c;好处在于运行更快&#xff0c;迭代更高效。构建速度快&#xff1a;无需将IL转成C&#xff0c;构建过程省去了IL2CPP的转换和原生编译步骤&#xff0c;适合开发阶段快…

React源码4 三大核心模块之一:Schedule,scheduleUpdateOnFiber函数

scheduler工作阶段在React内部被称为schedule阶段。在《React源码3》&#xff0c;我们已经将update加入队列并返回到了根容器节点root。function updateContainer(element, container, parentComponent, callback) {//前面略过var root enqueueUpdate(current$1, update, lane…

Unity3D + VS2022连接雷电模拟器调试

本文参考了Unity3D Profiler 连接真机和模拟器_unity 连接雷电模拟器-CSDN博客 具体步骤&#xff1a; 1、cmd打开命令窗口&#xff0c;输入adb devices&#xff0c;确认能检测到模拟器 示例&#xff1a;List of devices attached emulator-5554 device 2、…

学习软件测试的第十五天

1.会写测试用例吗&#xff1f;测试用例有什么要素“会的&#xff0c;我写过多个功能测试和接口测试的测试用例。我写用例的时候会根据需求文档或原型图分析测试点&#xff0c;然后从正常流程、异常流程、边界情况等方面设计测试场景。每条用例我都会包含&#xff1a;用例编号、…

C++硬实时调度:原理、实践与最佳方案

在工业自动化、航空航天、医疗设备等领域&#xff0c;系统的实时性往往直接关系到生命安全和财产损失。C作为高性能编程语言&#xff0c;为硬实时系统开发提供了强大支持。本文将深入探讨C硬实时调度的核心技术&#xff0c;从操作系统原理到代码实现的全方位解析。 一、实时系统…

LeetCode 1156.单字符重复子串的最大长度

如果字符串中的所有字符都相同&#xff0c;那么这个字符串是单字符重复的字符串。 给你一个字符串 text&#xff0c;你只能交换其中两个字符一次或者什么都不做&#xff0c;然后得到一些单字符重复的子串。返回其中最长的子串的长度。 示例 1&#xff1a; 输入&#xff1a;text…