首页 分享 用python搭建一个花卉识别系统

用python搭建一个花卉识别系统

来源:花匠小妙招 时间:2024-10-31 02:23

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

import torch

import torch.nn as nn

from torchvision import transforms, datasets, utils

import matplotlib.pyplot as plt

import numpy as np

import torch.optim as optim

from model import AlexNet

import os

import json

import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(os.path.join("train.log"), "a") as log:

    log.write(str(device)+"n")

data_transform = {

    "train": transforms.Compose([transforms.RandomResizedCrop(224),      

                                 transforms.RandomHorizontalFlip(p=0.5), 

                                 transforms.ToTensor(),

                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

    "val": transforms.Compose([transforms.Resize((224, 224)), 

                               transforms.ToTensor(),

                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))        

image_path = data_root + "/jqsj/data_set/flower_data/"                         

train_dataset = datasets.ImageFolder(root=image_path + "/train",       

                                     transform=data_transform["train"])

train_num = len(train_dataset)

train_loader = torch.utils.data.DataLoader(train_dataset,  

                                           batch_size=32,  

                                           shuffle=True,   

                                           num_workers=0)  

validate_dataset = datasets.ImageFolder(root=image_path + "/val",

                                        transform=data_transform["val"])

val_num = len(validate_dataset)

validate_loader = torch.utils.data.DataLoader(validate_dataset,

                                              batch_size=32,

                                              shuffle=True,

                                              num_workers=0)

flower_list = train_dataset.class_to_idx

cla_dict = dict((val, key) for key, val in flower_list.items())

json_str = json.dumps(cla_dict, indent=4)

with open('class_indices.json', 'w') as json_file:

    json_file.write(json_str)

net = AlexNet(num_classes=5, init_weights=True)      

net.to(device)                                       

loss_function = nn.CrossEntropyLoss()                

optimizer = optim.Adam(net.parameters(), lr=0.0002)  

save_path = './AlexNet.pth'

best_acc = 0.0

for epoch in range(150):

    net.train()                        

    running_loss = 0.0                 

    time_start = time.perf_counter()   

    for step, data in enumerate(train_loader, start=0): 

        images, labels = data  

        optimizer.zero_grad()  

        outputs = net(images.to(device))                

        loss = loss_function(outputs, labels.to(device))

        loss.backward()                                 

        optimizer.step()                                

        running_loss += loss.item()

        rate = (step + 1) / len(train_loader)          

        a = "*" * int(rate * 50)

        b = "." * int((1 - rate) * 50)

        with open(os.path.join("train.log"), "a") as log:

              log.write(str("rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss))+"n")

        print("rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")

    print()

    with open(os.path.join("train.log"), "a") as log:

              log.write(str('%f s' % (time.perf_counter()-time_start))+"n")

    print('%f s' % (time.perf_counter()-time_start))

    net.eval()   

    acc = 0.0 

    with torch.no_grad():

        for val_data in validate_loader:

            val_images, val_labels = val_data

            outputs = net(val_images.to(device))

            predict_y = torch.max(outputs, dim=1)[1] 

            acc += (predict_y == val_labels.to(device)).sum().item()   

        val_accurate = acc / val_num

        if val_accurate > best_acc:

            best_acc = val_accurate

            torch.save(net.state_dict(), save_path)

        with open(os.path.join("train.log"), "a") as log:

              log.write(str('[epoch %d] train_loss: %.3f  test_accuracy: %.3f n' %

              (epoch + 1, running_loss / step, val_accurate))+"n")

        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f n' %

              (epoch + 1, running_loss / step, val_accurate))

with open(os.path.join("train.log"), "a") as log:

      log.write(str('Finished Training')+"n")

print('Finished Training')

相关知识

python 花卉识别系统 用python搭建一个花卉识别系统(IT技术)
人工智能毕业设计基于python的花朵识别系统
深度学习基于python+TensorFlow+Django的花朵识别系统
用python搭建一个花卉识别系统
基于深度学习的花卉检测与识别系统(YOLOv5清新界面版,Python代码)
基于YOLOv8深度学习的102种花卉智能识别系统【python源码+Pyqt5界面+数据集+训练代码】目标识别、深度学习实战
基于YOLOv8深度学习的水稻害虫检测与识别系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战
基于YOLOv8深度学习的智能玉米害虫检测识别系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战
花卉识别python
用python画花瓣

网址: 用python搭建一个花卉识别系统 https://www.huajiangbk.com/newsview304796.html

所属分类:花卉
上一篇: 如何使用 iPhone 相机识别
下一篇: 基于pytorch搭建AlexN

推荐分享