- 🍨 This article is a learning log from the 🔗365-day deep learning training camp
- 🍖 Original author: K同学啊
Basic tasks:
1. Understand the basic principles of the conditional generative adversarial network (CGAN)
2. Understand how a CGAN achieves conditional control
3. Study the CGAN code in this article and run it end to end
Advanced task:
Generate images of a specified hand gesture
I. Theoretical Background
Conditional generative adversarial network (CGAN): an improvement on the generative adversarial network (GAN). The images produced by the original GAN's generator are random and unpredictable, so we cannot control what the network outputs, which limits its usefulness in practice.
To address the original GAN's inability to generate images with specific attributes, Mehdi Mirza et al. proposed the conditional generative adversarial network in 2014. It adds an extra condition to both the generator G and the discriminator D. For example, if we want the generator G to produce an image without shadows, the discriminator D must judge whether the generated image really is shadow-free.
The essence of a CGAN is to inject additional information into both the generator and the discriminator. The added information can be an image class label, a facial expression, or other auxiliary data. The goal is to turn the unsupervised GAN into a supervised CGAN so that the network trains in a way we can control.
CGAN network architecture diagram:
As the diagram shows, the conditional information y is introduced into the adversarial network as an extra input: in the generator it is combined with the noise z to form the hidden-layer representation, while in the discriminator D it is combined with the real data x as the input to the discriminant function. This modification has since proven effective in many lines of research and has provided useful guidance for follow-up work.
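Formally, the CGAN keeps the original GAN min-max game but conditions both networks on y; the objective from Mirza and Osindero's 2014 paper, stated here for reference, is:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
$$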
II. Preparation
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt
import datetime

torch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128
1. Load the data
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_dataset = datasets.ImageFolder(root="D:/study/data/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

def show_images(images):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))

def show_batch(dl):
    for images, _ in dl:
        show_images(images)
        break

show_batch(train_loader)
Problem encountered:
Caused by DataLoader multiprocessing, fixed by removing the num_workers argument.
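An alternative to dropping num_workers entirely is to keep the worker processes but guard the entry point; a minimal sketch (not from the original post), assuming the code is run as a script on Windows, where worker processes re-import the module:

# Sketch: keep num_workers by guarding the entry point, so that the worker
# processes spawned by the DataLoader do not re-execute module-level code.
if __name__ == "__main__":
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)
    show_batch(train_loader)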
image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))
latent_dim = 100
n_classes = 3
embedding_dim = 100
III. Building the Model
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)
1. Build the generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16))
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True))
        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh())

    def forward(self, inputs):
        noise_vector, label = inputs
        # Label path: embed the class index and project it to a 1x4x4 feature map.
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        # Noise path: project the latent vector to a 512x4x4 feature map.
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        # Concatenate along channels (512 + 1 = 513) and upsample to 3x128x128.
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)
Output:
from torchinfo import summary
summary(generator)
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)
Output:
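The snippet above only prepares a single noise vector a and a label b; a minimal sketch of the implied check, assuming the goal is to confirm that the generator produces a 3x128x128 image for one sample:

# Sketch (assumption: a quick shape check for one noise vector and one label).
out = generator((a, b))
print(out.shape)  # expected: torch.Size([1, 3, 128, 128])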
2. Build the discriminator
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3*128*128))
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid())

    def forward(self, inputs):
        img, label = inputs
        # Label path: embed the class index and reshape it into a 3x128x128 map.
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        # Concatenate with the input image along channels (3 + 3 = 6).
        concat = torch.cat((img, label_output), dim=1)
        output = self.model(concat)
        return output
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)
summary(discriminator)
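The final nn.Linear(4608, 1) follows from the spatial sizes produced by the four convolutions; a small sketch (assuming the standard Conv2d output-size formula) to check where 4608 comes from:

# Sketch: trace the spatial size of a 128x128 input through the discriminator's
# convolutions to confirm the flattened feature count 512 * 3 * 3 = 4608.
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

size = 128
for kernel, stride, padding in [(4, 2, 1), (4, 3, 2), (4, 3, 2), (4, 3, 2)]:
    size = conv_out(size, kernel, stride, padding)
print(size, 512 * size * size)  # 3 4608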
a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
c = discriminator((a,b))
c.size()
Output:
IV. Training the Model
1. Define the loss function
adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
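Both helpers wrap binary cross-entropy. In the training loop below the discriminator is trained against real and fake targets, while the generator is trained to make the discriminator output 1 on fake images (the standard non-saturating formulation). For reference, with the 1/2 factor used in the loop:

$$
\mathcal{L}_D = -\tfrac{1}{2}\Big(\log D(x, y) + \log\big(1 - D(G(z, y), y)\big)\Big),
\qquad
\mathcal{L}_G = -\log D(G(z, y), y)
$$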
2. Define the optimizers
learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
3. Train the model
Code logic diagram:
1. First, set the total number of training epochs and create lists to store the discriminator and generator losses of each epoch.
2. Then train the GAN. In each epoch, load real images and labels from the data loader, compute the discriminator's loss on the real images, generate fake images from noise vectors, compute the discriminator's loss on the fake images, compute the discriminator's total loss and backpropagate to update its parameters, then compute the generator's loss and backpropagate to update the generator's parameters.
3. Finally, print the average discriminator and generator losses for the current epoch and append them to the loss lists.
4. Every 10 epochs, save the generated fake images to an image file and save the current generator and discriminator weights to disk.
num_epochs = 100
D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):
    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        # ---- Train the discriminator ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        labels = labels.to(device)
        labels = labels.unsqueeze(1).long()

        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)
        D_total_loss.backward()
        D_optimizer.step()

        # ---- Train the generator ----
        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)
        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs,
        torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))
Output:
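The loop above collects per-epoch average losses in D_loss_plot and G_loss_plot but never visualizes them; a minimal sketch using the matplotlib import from earlier (not part of the original post):

# Sketch: plot the stored per-epoch average losses of both networks.
plt.figure(figsize=(10, 5))
plt.plot(D_loss_plot, label='Discriminator loss')
plt.plot(G_loss_plot, label='Generator loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()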
V. Model Analysis
1. Load the model
generator.load_state_dict(torch.load('./training_weights/generator_epoch_100.pth'), strict=False)
generator.eval()
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from numpy import linspace
from matplotlib import pyplot
from matplotlib import gridspec

def generate_latent_points(latent_dim, n_samples, n_classes=3):
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

def interpolate_points(p1, p2, n_steps=10):
    ratios = linspace(0, 1, num=n_steps)
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)

pts = generate_latent_points(100, 2)
interpolated = interpolate_points(pts[0], pts[1])
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

output = None
for label in range(3):
    labels = torch.ones(10) * label
    labels = labels.to(device)
    labels = labels.unsqueeze(1).long()
    print(labels.size())
    predictions = generator((interpolated, labels))
    predictions = predictions.permute(0, 2, 3, 1)
    pred = predictions.detach().cpu()
    if output is None:
        output = pred
    else:
        output = np.concatenate((output, pred))
output.shape
nrow = 3
ncol = 10

fig = plt.figure(figsize=(15, 4))
gs = gridspec.GridSpec(nrow, ncol)

k = 0
for i in range(nrow):
    for j in range(ncol):
        # Map the Tanh output from [-1, 1] back to [0, 255] for display.
        pred = (output[k, :, :, :] + 1) * 127.5
        pred = np.array(pred)
        ax = plt.subplot(gs[i, j])
        ax.imshow(pred.astype(np.uint8))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1
plt.show()
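For the advanced task (generating a specified gesture), it is enough to fix the label fed to the trained generator; a minimal sketch, assuming the class index follows ImageFolder's alphabetical ordering of the rps folders (check train_dataset.classes for the actual mapping):

# Sketch: generate n_samples images of one specified class with the trained generator.
# Assumption: target_class uses the ImageFolder class order (see train_dataset.classes).
target_class = 2
n_samples = 10
z = torch.randn(n_samples, latent_dim, device=device)
labels = torch.full((n_samples, 1), target_class, dtype=torch.long, device=device)
with torch.no_grad():
    fake = generator((z, labels))
save_image(fake, './images/class_%d.png' % target_class, nrow=5, normalize=True)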
VI. Summary
I studied the basic principles and code of the conditional generative adversarial network and learned how a CGAN implements conditional control. I had forgotten the problem I ran into last time and hit it again, so I need to review more often.