Flower Classification Based on the VGG Network
VGG - the idea of building a network out of repeated blocks first appeared in the VGG network from the Visual Geometry Group (VGG) at the University of Oxford. Using loops and subroutines, these repeated structures are easy to implement in code in any modern deep learning framework.
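As a minimal sketch of that idea (the helper name vgg_block and its arguments are illustrative, not part of the code further below), a stage of repeated 3x3 convolutions followed by max-pooling can be produced by one small function:

import torch.nn as nn

def vgg_block(num_convs, in_channels, out_channels):
    # stack num_convs 3x3 convolutions (each followed by ReLU),
    # then halve the spatial size with a 2x2 max-pooling layer
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

# e.g. the first VGG-16 stage: two 3->64 convolutions plus pooling
block1 = vgg_block(num_convs=2, in_channels=3, out_channels=64)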
Model
The VGG network can be divided into two parts: the first part consists mainly of convolutional and pooling layers, and the second part consists of fully connected layers.
Four different variants (vgg11, vgg13, vgg16, vgg19) can be built from the cfgs dictionary below. Five max-pooling stages reduce a 224x224 input to 7x7, which is why the classifier starts from 512*7*7 flattened features.
import torch.nn as nn
import torch

cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


# convolutional feature extractor built from a cfg list
def make_features(cfg: list):
    in_channels = 3
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            layers.append(nn.ReLU(True))
            in_channels = v
    return nn.Sequential(*layers)


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        # backbone feature extractor
        self.features = features
        # fully connected classifier
        self.classifier = nn.Sequential(nn.Dropout(p=0.5),
                                        nn.ReLU(True),
                                        nn.Linear(512 * 7 * 7, 2048),
                                        nn.Dropout(p=0.5),
                                        nn.ReLU(True),
                                        nn.Linear(2048, 2048),
                                        nn.Dropout(p=0.5),
                                        nn.ReLU(True),
                                        nn.Linear(2048, num_classes))
        if init_weights:
            self._initialize_weights()

    # forward pass
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    # weight initialization (used when init_weights=True is passed from the training script)
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except KeyError:
        print("Warning: model name {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)
    return model
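As a quick sanity check of the model code, a dummy batch can be pushed through the network to confirm that a 224x224 input yields one logit per class (this sketch assumes the code above is importable, e.g. as the VGG module used by the training script below):

import torch
# from VGG import vgg  # assuming the model code above is saved as VGG.py

# dummy batch of two 224x224 RGB images
x = torch.randn(2, 3, 224, 224)
net = vgg(model_name="vgg16", num_classes=5, init_weights=True)
out = net(x)
print(out.shape)  # expected: torch.Size([2, 5])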
Training
import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from VGG import vgg


def main():
    # train on GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # data preprocessing: random crop and horizontal flip for training,
    # resize for validation; then convert to tensor and normalize
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # data root directory
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set directory
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    # build the training set and read the class-to-index mapping
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write the index-to-class dict into a json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # build the model
    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    # loss function
    loss_function = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    # path for saving the best weights
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # switch to train mode
        net.train()
        running_loss = 0.0
        # progress bar
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            # images and labels
            images, labels = data
            # clear gradients
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            # back-propagation
            loss.backward()
            # parameter update
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
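Before committing to all 30 epochs it can help to run a single batch through one optimization step; this is only a sketch and assumes train_loader, net, loss_function, optimizer and device are already set up exactly as in main() above:

# single-batch smoke test, run inside main() after the setup code above
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 224, 224]) torch.Size([32])
optimizer.zero_grad()
outputs = net(images.to(device))
loss = loss_function(outputs, labels.to(device))
loss.backward()
optimizer.step()
print("one-step loss:", loss.item())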
Prediction

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indices
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()
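Instead of reporting only the single most likely class, the softmax output can also be ranked with torch.topk; a small sketch that reuses the predict tensor and class_indict dictionary from the script above:

# top-3 predictions (reuses `predict` and `class_indict` from main() above)
top_probs, top_idxs = torch.topk(predict, k=3)
for prob, idx in zip(top_probs, top_idxs):
    print("class: {:10}  prob: {:.3f}".format(class_indict[str(idx.item())], prob.item()))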
Related Reading
Flower classification with a simple VGG deep learning network
A simple VGG application
A flower recognition classifier based on a pretrained VGG-19 model in PyTorch, reaching 97% accuracy
Flower recognition in practice with deep learning and transfer learning
Plant pest and disease recognition based on a lightweight VGG
Recognition of five kinds of flowers with Python
Flower image classification models built on classic network architectures in PyTorch
Research on crop disease recognition methods based on CNNs and transfer learning
A crop pest and disease recognition system based on deep learning
Goji berry disease recognition: an image classification model based on a discriminative deep belief network