TensorFlow深度学习实战——DCGAN详解与实现

TensorFlow深度学习实战——DCGAN详解与实现

    • 0. 前言
    • 1. DCGAN 架构
    • 2. 构建 DCGAN 生成手写数字图像
      • 2.1 生成器与判别器架构
      • 2.2 构建 DCGAN
    • 相关链接

0. 前言

深度卷积生成对抗网络 (Deep Convolutional Generative Adversarial Network, DCGAN) 是一种基于生成对抗网络 (Generative Adversarial Network, GAN) 的深度学习模型,主要用于生成图像。它结合了卷积神经网络 (Convolutional Neural Network,CNN) 和生成对抗网络的优势,以更高效地生成质量更高的图像。

1. DCGAN 架构

深度卷积生成对抗网络 (Deep Convolutional Generative Adversarial Network, DCGAN) 引入了卷积神经网络 (Convolutional Neural Network,CNN) 的结构,主要设计思想是使用卷积层而不使用池化层或分类层。使用卷积的步幅参数和转置卷积执行下采样(维度减少)和上采样(维度增加)。

相比于原始生成对抗网络 (Generative Adversarial Network, GAN),DCGAN 的主要变化包括:

  • 网络完全由卷积层组成。池化层替换为步幅卷积(即,在使用卷积层时,将步幅从 1 增加为 2 )用于判别器,而生成器使用转置卷积
  • 移除卷积后的全连接分类层
  • 为了提高训练的稳定性,在每个卷积层后使用批归一化

DCGAN 的基本思想与原始 GAN 相同,生成器接受 100 维的噪声输入,经过全连接层后重塑形状后,通过卷积层处理,生成器架构如下:

生成器架构

判别器接收图像(可以是生成器生成的图像或来自真实数据集的图像),图像经过卷积处理和批归一化处理。在每一步卷积中通过步幅参数进行下采样。卷积层的最终输出展平后,输入到一个具有单个神经元的分类层:

判别器

生成器和判别器组合在一起形成 DCGAN。训练过程与原始 GAN 相同,首先在一个批数据上训练判别器,然后冻结判别器,训练生成器,并重复以上过程。实践证明,使用学习率为 0.002Adam 优化器能得到更稳定的结果。接下来,使用 Tensorflow 实现一个用于生成 MNIST 手写数字图像的 DCGAN

2. 构建 DCGAN 生成手写数字图像

在本节中,构建一个用于生成 MNIST 手写数字图像的 DCGAN

2.1 生成器与判别器架构

生成器通过顺序添加网络层构建。第一层是一个全连接层,接受 100 维的噪声作为输入,全连接层将 100 维的输入扩展为一个大小为 128 × 7 × 7 的一维向量。这样做的目的是为了最终得到大小为 28 × 28 的输出,也就是 MNIST 手写数字图像的标准大小。该向量重塑为一个大小为 7 × 7 × 128 的张量,然后使用 TensorFlowUpSampling2D 层进行上采样。需要注意的是,该层只是通过将行和列翻倍来放大图像,并没有可训练权重,因此计算开销较小。
Upsampling2D 层将 7 × 7 × 128 (行 × 列 × 通道)的图像的行和列翻倍,得到大小 14 × 14 × 128 的输出。上采样后的图像传递给一个卷积层,卷积层学习填充上采样图像中的细节,卷积的输出传递到批归一化层。批归一化后的输出经过 ReLU 激活。重复以上结构,即:上采样-卷积-批归一化-ReLU。在生成器中,具有两个这样的结构,第一个卷积层中使用 128 个卷积核,第二个使用 64 个卷积核。最终输出使用一个卷积层,使用尺寸为 3 x 3 的单个卷积核和 tanh 激活函数,生成 28 × 28 × 1 的图像:

    def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(Conv2D(self.channels, kernel_size=3, padding="same"))model.add(Activation("tanh"))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)

生成器模型架构如下:

生成器架构

也可以使用转置卷积层,转置卷积层不仅对输入图像进行上采样,而且在训练过程中学习如何填充细节。因此,可以用一个转置卷积层来替代上采样和卷积层,转置卷积层执行的是反卷积操作。
接下来,构建判别器。判别器类似于标准卷积神经网络,但区别在于,使用步幅为 2 的卷积层来代替最大池化层。还添加了 dropout 层以避免过拟合,并使用批归一化以提高准确性和加快收敛速度,激活函数使用 leaky ReLU。在判别器中,使用了三个卷积层,分别具有 3264128 个卷积核。最后一个卷积层的输出展平后传递给一个具有单个单元的全连接层。输出用于将图像分类为真实图像或伪造图像:

    def build_discriminator(self):model = Sequential()model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Flatten())model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

判别器模型架构如下:

判别器架构

2.2 构建 DCGAN

通过将生成器和判别器组合在一起得到完整的 GAN

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adamimport matplotlib.pyplot as plt
import sys
import numpy as npclass DCGAN():def __init__(self, rows, cols, channels, z = 10):# Input shapeself.img_rows = rowsself.img_cols = colsself.channels = channelsself.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = zoptimizer_1 = Adam(0.0002, 0.5)optimizer_2 = Adam(0.0002, 0.5)# Build and compile the discriminatorself.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer_1,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise as input and generates imgsz = Input(shape=(self.latent_dim,))img = self.generator(z)# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated images as input and determines validityvalid = self.discriminator(img)# The combined model  (stacked generator and discriminator)# Trains the generator to fool the discriminatorself.combined = Model(z, valid)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer_2)def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(UpSampling2D())model.add(Conv2D(64, kernel_size=3, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(Activation("relu"))model.add(Conv2D(self.channels, kernel_size=3, padding="same"))model.add(Activation("tanh"))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):model = Sequential()model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))model.add(ZeroPadding2D(padding=((0,1),(0,1))))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))model.add(BatchNormalization(momentum=0.8))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Flatten())model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

使用 binary_crossentropy 损失函数定义生成器和判别器的损失。生成器和判别器的优化器在初始化方法中定义。最后,定义了一个 TensorFlow 检查点,用于在模型训练过程中保存生成器和判别器模型。
DCGAN 的训练过程与原始 GAN 相同,在每一步中,首先将随机噪声输入到生成器中。生成器的输出与真实图像用于训练判别器,然后训练生成器,使其生成能够欺骗判别器的图像。GAN 的训练通常需要几百到数千个训练 epoch

    def train(self, epochs, batch_size=256, save_interval=50):# Load the dataset(X_train, _), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------#  Train Discriminator# ---------------------# Select a random half of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# Sample noise and generate a batch of new imagesnoise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)# Train the discriminator (real classified as ones and generated as zeros)d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train Generator# ---------------------# Train the generator (wants discriminator to mistake images as real)g_loss = self.combined.train_on_batch(noise, valid)# Plot the progressprint ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# If at save interval => save generated image samplesif epoch % save_interval == 0:self.save_imgs(epoch)

最后,定义辅助函数保存图像:

    def save_imgs(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/dcgan_mnist_%d.png" % epoch)plt.close()

训练 DCGAN 模型:

dcgan = DCGAN(28,28,1)
dcgan.train(epochs=5000, batch_size=128, save_interval=50)

随着训练的进行,GAN 学习生成手写数字的能力逐渐增强:

训练监控

在第 50 个训练 epoch,生成的手写数字图像质量有了显著提升:

结果图像

下图是将 DCGAN 应用到名人图像数据集中的一些生成结果:

生成结果

相关链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络
TensorFlow深度学习实战(9)——构建VGG模型实现图像分类
TensorFlow深度学习实战(10)——迁移学习详解
TensorFlow深度学习实战(11)——风格迁移详解
TensorFlow深度学习实战(14)——循环神经网络详解
TensorFlow深度学习实战(15)——编码器-解码器架构
TensorFlow深度学习实战(16)——注意力机制详解
TensorFlow深度学习实战(23)——自编码器详解与实现
TensorFlow深度学习实战(24)——卷积自编码器详解与实现
TensorFlow深度学习实战(25)——变分自编码器详解与实现
TensorFlow深度学习实战(26)——生成对抗网络详解与实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89303.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot 使用MyBatisPlus

引入依赖<dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-openapi3-jakarta-spring-boot-starter</artifactId><version>4.3.0</version> </dependency>写一个interface 继承basemapMapper public in…

Git 中如何查看提交历史?常用命令有哪些?

回答重点在 Git 中&#xff0c;我们可以使用 git log 命令来查看提交历史。这个命令会列出所有的提交记录&#xff0c;显示每个提交的哈希值、作者信息、提交时间和提交信息。常用的 git log 命令及其选项有&#xff1a;1&#xff09; git log &#xff1a;显示完整的提交历史。…

Flink数据流高效写入MySQL实战

这段代码展示了如何使用 Apache Flink 将数据流写入 MySQL 数据库&#xff0c;并使用了 JdbcSink 来实现自定义的 Sink 逻辑。以下是对代码的详细解析和说明&#xff1a;代码结构包声明&#xff1a;package sink定义了代码所在的包。导入依赖&#xff1a;导入了必要的 Flink 和…

MATLAB下载安装教程(附安装包)2025最新版(MATLAB R2024b)

文章目录前言一、MATLAB R2024b下载二、MATLAB下载安装教程前言 MATLAB R2024b 的推出&#xff0c;进一步提升了其在工程实践中的实用性和专业性。它不仅提供了更多针对特定工程领域的解决方案&#xff0c;还在性能和兼容性方面进行了显著改进。 本教程将一步一步引导完成 MA…

Linux 基础命令学习,立即上手Linux操作

Linux 基础命令学习本文挑选最常用、最容易上手的 Linux 命令。每条都附带一句话说明 真实示例&#xff0c;直接复制即可练习&#xff0c;零基础也能跟得上。1  先掌握 目录导航&#xff1a;pwd / ls / cdpwd – 显示当前所在目录 pwd # 输出示例 /home/yournamels ‑a…

Android构建流程与Transform任务

1. 完整构建流程概览 1.1 主要构建阶段 预构建阶段 → 代码生成阶段 → 资源处理阶段 → 编译阶段 → Transform阶段 → 打包阶段1.2 详细任务执行顺序 ┌─────────────────────────────────────────────────────────…

CKS认证 | Day6 监控、审计和运行时安全 sysdig、falco、审计日志

一、分析容器系统调用&#xff1a;Sysdig Sysdig&#xff1a;定位是系统监控、分析和排障的工具&#xff0c;在 linux 平台上&#xff0c;已有很多这方面的工具 如tcpdump、htop、iftop、lsof、netstat&#xff0c;它们都能用来分析 linux 系统的运行情况&#xff0c;而且还有…

Redis:持久化配置深度解析与实践指南

&#x1f9e0; 1、简述 Redis 是一款基于内存的高性能键值数据库&#xff0c;为了防止数据丢失&#xff0c;Redis 提供了两种主要的持久化机制&#xff1a;RDB&#xff08;快照&#xff09;和 AOF&#xff08;追加日志&#xff09;。本文将从原理到配置&#xff0c;再到实际项目…

共创 Rust 十年辉煌时刻:RustChinaConf 2025 赞助与演讲征集正式启动

&#x1f680; 共创 Rust 十年辉煌时刻&#xff1a;RustChinaConf 2025 赞助与演讲征集正式启动2025年&#xff0c;是 Rust 编程语言诞生十周年的里程碑时刻。在这个具有历史意义的节点&#xff0c;RustChinaConf 2025 携手 RustGlobal 首次登陆中国&#xff0c;联合 GOSIM HAN…

EMS4100芯祥科技USB3.1高速模拟开关芯片规格介绍

EMS4100一款适用于USB Type-C应用的二通道差分2:1/1:2 USB 3.1高速双向被动开关。该器件支持USB 3.1 Gen 1和Gen 2数据速率,具有高带宽、低串扰、宽供电电压范围等特点。EMS4100芯片内部框架&#xff1a;EMS4100主要特性&#xff1a;2-独立频道1&#xff1a;2/2&#xff1a;1 M…

HTML 常用语义标签与常见搭配详解

一、什么是语义标签&#xff1f; 语义标签是 HTML5 引入的一组具有特定含义的标签&#xff0c;用于描述页面中不同部分的内容类型&#xff0c;如页眉、导航栏、主内容区域、侧边栏、页脚等。相比传统的 <div> 和 <span>&#xff0c;语义标签更具表达力和结构化。 …

迁移学习的概念和案例

迁移学习概念 预训练模型 定义: 简单来说别人训练好的模型。一般预训练模型具备复杂的网络模型结构&#xff1b;一般是在大量的语料下训练完成的。 预训练语言模型的类别&#xff1a; 现在我们接触到的预训练语言模型&#xff0c;基本上都是基于transformer这个模型迭代而来…

DAOS系统架构-RDB

1. 概述 基于Raft共识算法和强大的领导地位策略&#xff0c;pool service和container service可以通过复制其内部的元数据来实现高可用。通过这种方法实现具有副本能力的服务可以容忍少数副本中的任何一个出现故障。通过将每个服务的副本分布在容灾域中&#xff0c;pool servic…

深入GPU硬件架构及运行机制

转自深入GPU硬件架构及运行机制 - 0向往0 - 博客园&#xff0c;基本上是其理解。 一、GPU概述 1.1 GPU是什么&#xff1f; GPU全称是Graphics Processing Unit&#xff0c;图形处理单元。它的功能最初与名字一致&#xff0c;是专门用于绘制图像和处理图元数据的特定芯片&…

数值计算库:Eigen与Boost.Multiprecision全方位解析

在科学计算、工程模拟、机器学习等领域&#xff0c;高效的数值计算能力是构建高性能应用的基石。C作为性能优先的编程语言&#xff0c;拥有众多优秀的数值计算库&#xff0c;其中Eigen和Boost.Multiprecision是两个极具代表性的工具。本文将深入探讨这两个库的核心特性、使用场…

第十八节:第三部分:java高级:反射-获取构造器对象并使用

Class提供的获取类构造器的方法以及获取类构造器的作用代码&#xff1a;掌握获取类的构造器&#xff0c;并对其进行操作 Cat类 package com.itheima.day9_reflect;public class Cat {private String name;private int age;private Cat(String name, int age) {this.name name;…

集中打印和转换Office 批量打印精灵:Word/Excel/PDF 全兼容,效率翻倍

各位办公小能手们&#xff01;你们平时办公的时候&#xff0c;是不是经常要打印一堆文件&#xff0c;烦得要命&#xff1f;别慌&#xff0c;今天我给大家介绍一款超厉害的神器——Office批量打印精灵&#xff01; 软件下载地址安装包 这玩意儿啊&#xff0c;是专门为高效办公设…

docker的搭建

一、安装docker使用以下命令进行安装dockerapt-get install docker.io docker-compose使用以下命令进行查看docker是否开启systemctl status docker由此可见&#xff0c;docker没有打开&#xff0c;进行使用命令打开。systemctl start docker再次查看是否开启。肉眼可见&#x…

数据库管理-第349期 Oracle DB 23.9新特性一览(20250717)

数据库管理349期 2025-07-17数据库管理-第349期 Oracle DB 23.9新特性一览&#xff08;20250717&#xff09;1 JavaScript过程和函数的编译时语法检查2 不再需要JAVASCRIPT上的EXECUTE权限3 GROUP BY ALL4 使用SQL创建并测试UUID5 IVF索引在线重组6 JSON到二元性迁移器&#xf…

将CSDN文章导出为PDF

作者&#xff1a;翟天保Steven 版权声明&#xff1a;著作权归作者所有&#xff0c;商业转载请联系作者获得授权&#xff0c;非商业转载请注明出处前言在日常学习和技术积累过程中&#xff0c;我们经常会在 CSDN 等技术博客平台上阅读高质量的技术文章。然而&#xff0c;网页阅读…