
GAN (Generative Adversarial Network): Flower Image Generation

Source: 花匠小妙招 | Date: 2024-11-28 07:19
Contents
- Introduction
- 1. GAN Basics
- 2. Dataset
- 3. Implementation: Parameter Settings / Data Processing / Building the Networks / Optimizers and Loss Function / Training / Saving the Model / Results
- Summary

Introduction

This article builds a GAN (generative adversarial network) with PyTorch to generate flower images.

1. GAN Basics

The fundamentals of GANs are covered in detail in the following article (in Chinese), which is a useful reference:
GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)
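As a quick reminder, the standard objective from the original GAN paper is a two-player minimax game between the generator G and the discriminator D:

```latex
\min_G \max_D \; V(D, G) =
\mathbb{E}_{x \sim p_{\text{data}}}\!\left[\log D(x)\right] +
\mathbb{E}_{z \sim p_z}\!\left[\log\bigl(1 - D(G(z))\bigr)\right]
```

The discriminator maximizes this value (classifying real as real, fake as fake), while the generator minimizes it (making D(G(z)) approach 1). The training loop later in this article implements exactly this alternation using binary cross-entropy.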

2. Dataset

This article uses a flower dataset of 4,317 images spanning five species: daisy, dandelion, rose, sunflower, and tulip. I have already split it into a training set and a test set; only the training set is used here. The directory layout follows the one-subfolder-per-class convention expected by `ImageFolder`:
[dataset directory screenshots]
The dataset is available for download at the link below:
花卉数据集 (flower dataset)

3. Implementation

Parameter Settings

step1. Parameter continue_train: whether to resume training from saved weights
step2. Parameter dir: path to the training set
step3. Parameter batch_size: number of images per training step
step4. Parameter device: GPU id to use
step5. Parameter epochs: number of training epochs
step6. Parameter generator_num: train the generator once every k batches
step7. Parameter discriminator_num: train the discriminator once every k batches

```python
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # `action='store_true'` instead of `type=bool`: bool() applied to any
    # non-empty string (including 'False') is True
    parser.add_argument('--continue_train', action='store_true', help='continue training')
    parser.add_argument('--dir', type=str, default='./flowers/train', help='dataset path')
    parser.add_argument('--batch_size', type=int, default=50, help='batch size')
    parser.add_argument('--device', type=int, default=0, help='GPU id')
    parser.add_argument('--epochs', type=int, default=200, help='train epochs')
    parser.add_argument('--generator_num', type=int, default=5, help='train generator every k batches')
    parser.add_argument('--discriminator_num', type=int, default=1, help='train discriminator every k batches')
    args = parser.parse_args()
    main(args)
```
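One caveat worth knowing, and the reason `action='store_true'` is preferable to `type=bool` for flags like `--continue_train`: argparse applies the type callable to the raw string, and `bool()` of any non-empty string is `True`. A minimal demonstration (the flag names here are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# the pitfall: bool('False') evaluates to True, so this flag cannot be disabled
parser.add_argument('--continue_train_bool', type=bool, default=False)
# the safer flag style: absent -> False, present -> True
parser.add_argument('--continue_train', action='store_true')

args = parser.parse_args(['--continue_train_bool', 'False'])
print(args.continue_train_bool)  # True — the string 'False' is truthy
print(args.continue_train)       # False — flag was not passed
```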

Data Processing

step1. Define the transforms applied to training images before they are fed to the discriminator
step2. Build the Dataset and DataLoader

```python
transform = transforms.Compose([
    transforms.Resize((96, 96)),                            # resize images to 96 x 96
    transforms.ToTensor(),                                  # convert to a tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # map to [-1, 1]
])
data_set = datasets.ImageFolder(root=args.dir, transform=transform)
data_loader = dataloader.DataLoader(dataset=data_set, batch_size=args.batch_size,
                                    num_workers=4, shuffle=True, drop_last=True)
print('already load data...')
```
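A side note on the Normalize step: after `ToTensor()` scales pixels to [0, 1], normalizing with mean 0.5 and std 0.5 maps each channel to [-1, 1], which matches the generator's Tanh output range. A tiny sketch of the forward and inverse mapping (the helper names are mine, not part of torchvision):

```python
# Normalize(mean=0.5, std=0.5) applies x -> (x - 0.5) / 0.5 per channel,
# mapping [0, 1] to [-1, 1] so real images share the Tanh output range.
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

def denormalize(y):
    # inverse mapping, used later when plotting: img * 0.5 + 0.5
    return y * 0.5 + 0.5

print(normalize(0.0), normalize(1.0))  # -1.0 1.0
print(denormalize(normalize(0.25)))    # 0.25
```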

Building the Networks

step1. The generator uses transposed convolutions and outputs 3 x 96 x 96 images with pixel values in [-1, 1]
step2. The discriminator uses convolutions and outputs the probability that its input is real

```python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # layers run in the order they are passed to nn.Sequential
            nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512 x 4 x 4: (1-1)*1 - 2*0 + 4 = 4
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 8 x 8: (4-1)*2 - 2*1 + 4 = 8
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 32 x 32
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()
            # 3 x 96 x 96
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 32 x 32
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 16 x 16
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 8 x 8
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 4 x 4
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
            # scalar probability that the input is real
        )

    def forward(self, input):
        return self.main(input).view(-1)
```

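The spatial sizes annotated in the comments above can be verified with the standard PyTorch output-size formulas (for dilation 1 and no output padding); a quick sketch:

```python
def deconv_out(size, kernel, stride, padding):
    # nn.ConvTranspose2d output size (dilation=1, output_padding=0)
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size, kernel, stride, padding):
    # nn.Conv2d output size (dilation=1)
    return (size + 2 * padding - kernel) // stride + 1

# Generator path: 1 -> 4 -> 8 -> 16 -> 32 -> 96
g = 1
for k, s, p in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (5, 3, 1)]:
    g = deconv_out(g, k, s, p)

# Discriminator path: 96 -> 32 -> 16 -> 8 -> 4 -> 1
d = 96
for k, s, p in [(5, 3, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]:
    d = conv_out(d, k, s, p)

print(g, d)  # 96 1
```

The two paths mirror each other: the generator upsamples a 100 x 1 x 1 noise vector to a 96 x 96 image, and the discriminator collapses a 96 x 96 image to a single probability.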

Optimizers and Loss Function

step1. Both the generator and the discriminator use the Adam optimizer
step2. Use binary cross-entropy (BCELoss) as the loss function

```python
optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
loss = nn.BCELoss()
print('already prepared optimizer and loss_function...')
```
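For intuition, the per-sample quantity that `nn.BCELoss()` averages over a batch can be computed by hand (a sketch for illustration, not part of the training code):

```python
import math

# binary cross-entropy: BCE(p, y) = -(y*log(p) + (1-y)*log(1-p))
def bce(p, y):
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Discriminator on a real image (label 1): loss = -log(p), small when p -> 1.
# Discriminator on a fake image (label 0): loss = -log(1-p), small when p -> 0.
# The generator reuses label 1 on fake outputs, pushing D(G(z)) toward 1.
print(bce(0.9, 1.0))  # -log(0.9), a confident correct call -> small loss
print(bce(0.1, 0.0))  # -log(0.9), equally small for a confident rejection
```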

Training

Every discriminator_num batches:
step1. Feed a batch of real images to the discriminator
step2. Generate images from random noise with the generator and feed them to the discriminator
step3. Compute the discriminator loss (real judged real, fake judged fake), backpropagate, and update the discriminator
Every generator_num batches:
step4. Generate images from random noise with the generator and feed them to the discriminator
step5. Compute the generator loss (fake judged real), backpropagate, and update the generator
step6. Save sample outputs every 100 epochs

```python
print('start training...')
for epoch in range(args.epochs):
    print('epoch:{}'.format(epoch + 1))
    for i, data in enumerate(data_loader):
        if (i + 1) % args.discriminator_num == 0:
            optimizer_D.zero_grad()
            real_img = data[0]
            batchsize = len(real_img)
            real_img = real_img.cuda(args.device)
            out_D_real = model_D(real_img)
            real_labels = torch.ones(batchsize).cuda(args.device)
            loss_D_real = loss(out_D_real, real_labels)
            loss_D_real.backward()
            noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)
            fake_img = model_G(noise)
            # detach so the discriminator update does not backpropagate into G
            out_D_fake = model_D(fake_img.detach())
            fake_labels = torch.zeros(batchsize).cuda(args.device)
            loss_D_fake = loss(out_D_fake, fake_labels)
            loss_D_fake.backward()
            optimizer_D.step()
        if (i + 1) % args.generator_num == 0:
            optimizer_G.zero_grad()
            real_img = data[0]
            batchsize = len(real_img)
            noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)
            fake_img = model_G(noise)
            out_D_fake = model_D(fake_img)
            real_labels = torch.ones(batchsize).cuda(args.device)
            loss_G = loss(out_D_fake, real_labels)
            loss_G.backward()
            optimizer_G.step()
    if (epoch + 1) % 100 == 0:
        fix_noise = torch.randn(40, 100, 1, 1).cuda(args.device)
        final_img = model_G(fix_noise)
        final_img = final_img * 0.5 + 0.5  # undo normalization: [-1, 1] -> [0, 1]
        final_img = final_img.cpu()
        plt.figure(1)
        for i in range(40):
            img = final_img[i].detach().numpy()
            plt.subplot(5, 8, i + 1)
            plt.imshow(np.transpose(img, (1, 2, 0)))
        plt.savefig("./outcome/{}.png".format(epoch + 1))
        plt.show()
print('end training...')
```

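The two modulo checks in the loop determine how often each network updates. With the default values (discriminator_num = 1, generator_num = 5), the discriminator updates on every batch and the generator on every fifth batch; a small sketch of that schedule:

```python
# which batch indices i trigger each update, for a run of 20 batches
discriminator_num, generator_num = 1, 5
d_steps = [i for i in range(20) if (i + 1) % discriminator_num == 0]
g_steps = [i for i in range(20) if (i + 1) % generator_num == 0]
print(len(d_steps))  # 20 — discriminator trains on every batch
print(g_steps)       # [4, 9, 14, 19] — generator trains on every 5th batch
```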

Saving the Model

```python
torch.save(model_G.state_dict(), './generator.pt')
torch.save(model_D.state_dict(), './discriminator.pt')
print('already saved model...')
```

Results

After 3000 training epochs, the generated samples look as follows:
[grid of generated flower images]

Summary

That concludes this walkthrough of image generation with a generative adversarial network. The complete code is as follows:

```python
import argparse
import torchvision.datasets as datasets
import torch.utils.data.dataloader as dataloader
import torchvision.transforms as transforms
import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as plt


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # layers run in the order they are passed to nn.Sequential
            nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 32 x 32
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()
            # 3 x 96 x 96
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 32 x 32
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 16 x 16
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 8 x 8
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 4 x 4
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
            # scalar probability that the input is real
        )

    def forward(self, input):
        return self.main(input).view(-1)


def main(args):
    transform = transforms.Compose([
        transforms.Resize((96, 96)),                            # resize images to 96 x 96
        transforms.ToTensor(),                                  # convert to a tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # map to [-1, 1]
    ])
    data_set = datasets.ImageFolder(root=args.dir, transform=transform)
    data_loader = dataloader.DataLoader(dataset=data_set, batch_size=args.batch_size,
                                        num_workers=4, shuffle=True, drop_last=True)
    print('already load data...')
    model_G = Generator()
    model_D = Discriminator()
    if args.continue_train:
        model_G.load_state_dict(torch.load('./generator.pt'))
        model_D.load_state_dict(torch.load('./discriminator.pt'))
    model_G.train()
    model_D.train()
    print('already prepared model...')
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
    loss = nn.BCELoss()
    print('already prepared optimizer and loss_function...')
    if torch.cuda.is_available():  # note: the training loop below assumes a CUDA device
        model_G.cuda(args.device)
        model_D.cuda(args.device)
        loss.cuda(args.device)
        print('already in GPU...')
    print('start training...')
    for epoch in range(args.epochs):
        print('epoch:{}'.format(epoch + 1))
        for i, data in enumerate(data_loader):
            if (i + 1) % args.discriminator_num == 0:
                optimizer_D.zero_grad()
                real_img = data[0]
                batchsize = len(real_img)
                real_img = real_img.cuda(args.device)
                out_D_real = model_D(real_img)
                real_labels = torch.ones(batchsize).cuda(args.device)
                loss_D_real = loss(out_D_real, real_labels)
                loss_D_real.backward()
                noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)
                fake_img = model_G(noise)
                # detach so the discriminator update does not backpropagate into G
                out_D_fake = model_D(fake_img.detach())
                fake_labels = torch.zeros(batchsize).cuda(args.device)
                loss_D_fake = loss(out_D_fake, fake_labels)
                loss_D_fake.backward()
                optimizer_D.step()
            if (i + 1) % args.generator_num == 0:
                optimizer_G.zero_grad()
                real_img = data[0]
                batchsize = len(real_img)
                noise = torch.randn(args.batch_size, 100, 1, 1).cuda(args.device)
                fake_img = model_G(noise)
                out_D_fake = model_D(fake_img)
                real_labels = torch.ones(batchsize).cuda(args.device)
                loss_G = loss(out_D_fake, real_labels)
                loss_G.backward()
                optimizer_G.step()
        if (epoch + 1) % 10 == 0:
            fix_noise = torch.randn(40, 100, 1, 1).cuda(args.device)
            final_img = model_G(fix_noise)
            final_img = final_img * 0.5 + 0.5  # undo normalization: [-1, 1] -> [0, 1]
            final_img = final_img.cpu()
            plt.figure(1)
            for i in range(40):
                img = final_img[i].detach().numpy()
                plt.subplot(5, 8, i + 1)
                plt.imshow(np.transpose(img, (1, 2, 0)))
            plt.savefig("./outcome/{}.png".format(epoch + 1))
            plt.show()
    print('end training...')
    torch.save(model_G.state_dict(), './generator.pt')
    torch.save(model_D.state_dict(), './discriminator.pt')
    print('already saved model...')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # `action='store_true'` instead of `type=bool`: bool('False') is True
    parser.add_argument('--continue_train', action='store_true', help='continue training')
    parser.add_argument('--dir', type=str, default='./flowers/train', help='dataset path')
    parser.add_argument('--batch_size', type=int, default=50, help='batch size')
    parser.add_argument('--device', type=int, default=0, help='GPU id')
    parser.add_argument('--epochs', type=int, default=3000, help='train epochs')
    parser.add_argument('--generator_num', type=int, default=5, help='train generator every k batches')
    parser.add_argument('--discriminator_num', type=int, default=1, help='train discriminator every k batches')
    args = parser.parse_args()
    main(args)
```



Original article: https://www.huajiangbk.com/newsview759325.html
