首页 分享 使用YOLOv8分类模型进行迁移训练对农作物病虫害识别分类

使用YOLOv8分类模型进行迁移训练对农作物病虫害识别分类

来源:花匠小妙招 时间:2024-11-12 07:03

YOLO模型能够进行图像检测、分类、分割、追踪等多种任务,我们将使用YOLO的分类模型进行一个十分基础的苹果叶片病虫害识别。

本文使用的pytorch版本为2.3.0+cu121

YOLOv8模型准备

首先为了能够调用yolo模型,我们首先要安装ultralytics库,直接使用pip安装即可:

pip install ultralytics

然后我还需要到官网下载yolo的预训练模型

yolov8预训练的分类模型主要有n、s、m、l、x五种,五种模型预测准确率不同,相应的内存与运算时间也不同,yolov8n内存最小、运算时间最快,但相应的准确率也较低,这里我选择的是yolov8m-cls预训练模型,大家可以根据自己的设备性能选择适合自己的预训练模型。

数据准备

本文使用的苹果叶片数据集总共有10种类别,每一种类别都代表着一种病害

yolo模型所要求的数据集格式类似于:

- 数据集/

- train/

- class1/

- class2/

- val/

- class1/

- class2/

- test/

- class1/

- class2/

 我们需要把数据集划分为train、valid和test三个文件夹,然后在每个文件夹下,再把图片数据划分到相应的类别(class1、class2)文件夹中。

我在这里使用的苹果叶片数据集已经以8:1:1的比例划分好了训练集、验证集和测试集,下载链接等下会放在文末。

训练模型

下面我们将开始训练模型

首先,导入需要的库

from ultralytics import YOLO

from torchvision.datasets import ImageFolder

from torchvision.transforms import ToTensor, Resize, Compose

from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score

import os

本人在训练的时候,遇到过 Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized这样的报错,据说是因为多个科学计算库不兼容导致的,因此我在文件中添加了下面代码,如果大家没有这样的问题可以不用理会

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

随后加载刚刚下好的预训练模型

model = YOLO('yolov8m-cls.pt')

然后就可以开始训练模型了

model.train(data="apple_dataset", epochs=50, imgsz=640)

data参数填你的数据集目录,注意:这里的apple_dataset应该是包含train、val、test三个文件夹的。

epochs参数为训练的周期,这里设置为50

imgsz为用来训练模型的图像大小,即模型会将图片重新缩放为imgsz的大小,这里我设置为640。imgsz会在一定程度上影响模型的预测准确率,imgsz越大,相应的准确率也会变高,但相应的训练速度也会变慢,所需内存也会变多。如果内存不足的话,大家可以把imgsz设置为448或256。

随后模型就会开始训练了

模型训练完成后会在当前目录下创建一个runs文件夹,里面的classify中记录了每次分裂模型训练的有关文件。

其中args.yaml记录了这次训练所设置的模型参数;confusion_matrix则是混淆矩阵,可以用来评估模型;events.out.tfevents开头的文件则是tensorboard的日志文件,可以通过加载这个文件来查看训练过程中超参数、模型损失等数据的变化;其他的还包括一些训练过程中每个batch具体的训练图片以及验证过程中的验证结果。

weight文件夹则是模型最后一个epoch时的checkpoint和模型表现最好的checkpoint。

测试模型

模型训练完毕后,我们可以看看模型在测试集上的表现

test_dir = "apple_datasettest"

transform = Compose([

Resize((640, 640)),

ToTensor(),

])

test_datasets = ImageFolder(test_dir, transform=transform)

test_dataloader = DataLoader(test_datasets, batch_size=1, shuffle=False)

preds = []

labels = []

for batch in test_dataloader:

img, label = batch

result = model.predict(img, verbose=False)

preds.append(result[0].probs.top1)

labels.append(label.item())

accuracy = accuracy_score(preds, labels)

print(f'Accuracy: {accuracy * 100:.2f} %')

 

可以看到模型在最后测试集的准确率上由99%,说明结果还不错

我们还可以随便选择一张图片,看看模型具体的预测结果

predict_img = "apple_datasettestBlack rotBlack rot (29).JPG"

predict_result = model.predict(predict_img, imgsz=640)

labels = predict_result[0].names

predictt_label = predict_result[0].probs.top5

predict_probility = predict_result[0].probs.top5conf.cpu().numpy()

for i in range(5):

print(f"是{labels[predictt_label[i]]}的概率为{predict_probility[i] * 100:.2f}%")

 最后的预测结果符合真实结果。

完整代码如下:

from ultralytics import YOLO

from torchvision.datasets import ImageFolder

from torchvision.transforms import ToTensor, Resize, Compose

from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score

import os

//os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

model = YOLO('yolov8m-cls.pt')

model.train(data="apple_dataset", epochs=50, imgsz=640)

// 预测数据

test_dir = "apple_datasettest"

transform = Compose([

Resize((640, 640)),

ToTensor(),

])

test_datasets = ImageFolder(test_dir, transform=transform)

test_dataloader = DataLoader(test_datasets, batch_size=1, shuffle=False)

preds = []

labels = []

for batch in test_dataloader:

img, label = batch

result = model.predict(img, verbose=False)

preds.append(result[0].probs.top1)

labels.append(label.item())

accuracy = accuracy_score(preds, labels)

print(f'Accuracy: {accuracy * 100:.2f} %')

// 预测图片

predict_img = "apple_datasettestBlack rotBlack rot (29).JPG"

predict_result = model.predict(predict_img, imgsz=640)

labels = predict_result[0].names

predictt_label = predict_result[0].probs.top5

predict_probility = predict_result[0].probs.top5conf.cpu().numpy()

for i in range(5):

print(f"是{labels[predictt_label[i]]}的概率为{predict_probility[i] * 100:.2f}%")

使用的苹果叶片数据集:

链接: https://pan.baidu.com/s/1_FO_005Q6i-nzR-Apx2I2Q?pwd=8mig 提取码: 8mig 

本文只是使用yolo模型完成了一个简单的数据分类任务,大家感兴趣的话可以深入了解一下yolo模型。除了分类任务外,yolo模型在目标检测和目标追踪等任务上也具有不错的表现。

相关知识

使用迁移学习对花卉进行分类
基于深度学习的农作物害虫检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的玉米病虫害检测系统(Python+PySide6界面+训练代码)
农作物病虫害识别进展概述
基于深度学习的玉米病虫害检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
YOLOv8系列】(七)毕设实战:YOLOv8+Pyqt5实现鲜花智能分类系统
智能识别花生病虫害:应用迁移学习与CNN
农作物病虫害识别技术的发展综述
基于深度学习的田间杂草检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
基于深度学习的植物叶片病害识别系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)

网址: 使用YOLOv8分类模型进行迁移训练对农作物病虫害识别分类 https://www.huajiangbk.com/newsview504798.html

所属分类:花卉
上一篇: 【达人分享】月季控必读的月季全年
下一篇: 月季类苗木花卉的修剪方法

推荐分享