首页 分享 小白学Pytorch使用(4

小白学Pytorch使用(4

来源:花匠小妙招 时间:2024-12-12 15:51

任务背景

利用resnet18网络结构及预训练模型参数进行102类别的花数据集分类,迁移学习冻结resnet18输出层外的权重参数更新,保存最好的训练模型pt文件。
数据如下:
花数据集+json文件+2个迁移学习模型

一、导入库

import os import matplotlib.pyplot as plt import numpy as np import torch from torch import nn import torch.optim as optim import torchvision # torchvision中的transforms模块自带数据增强、数据预处理功能;models预训练模型,如resnet模型;datasets文件夹 from torchvision import transforms, models, datasets import imageio import time import warnings warnings.filterwarnings("ignore") import random import sys import copy import json from PIL import Image

123456789101112131415161718

二、数据导入

# 导入数据路径——改为自己的数据路径 data_dir = r'D:咕泡人工智能-配套资料配套资料4.第四章 深度学习核⼼框架PyTorch第五章:图像识别模型与训练策略(重点)flower_data' train_dir = data_dir + '/train' valid_dir = data_dir + '/valid' 1234'

三、数据预处理

# 数据集量较少,需进行数据增强(Data Augmentation):旋转、裁剪、翻转(水平/垂直)、平移 data_transforms = { 'train': # Compose()顺序执行以下操作 transforms.Compose([ # 数据集尺寸不同,统一尺寸,可正方形(常用)、长方形,一般64、128、224、256 # 数据越小易损失特征,但训练速度加快;CPU一般96/64 transforms.Resize([96, 96]), transforms.RandomRotation(45), #随机旋转,-45到45度之间随机选 transforms.CenterCrop(64), #从中心开始裁剪,实际输入网络的图像大小以裁剪后大小为准,如裁剪后为64*64,则输入大小为64*64 transforms.RandomHorizontalFlip(p=0.5), #随机水平翻转 选择一个概率概率,每张图都有50%可能性水平翻转 transforms.RandomVerticalFlip(p=0.5), #随机垂直翻转,每张图都有50%可能性垂直翻转 # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#(极端光照条件下使用,一般不使用)参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相 # transforms.RandomGrayscale(p=0.025),#(一般不使用)概率转换成灰度率,3通道就是R=G=B,转成RRR/GGG/BBB transforms.ToTensor(), #数据转换为tensor格式 # 标准化操作。由于数据量较少不具有代表性,选用自己的均值标准差结果不稳定,因此选择大数据中提供的标准差 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#r、g、b三通道的均值u,标准差b,(x-u)/b ]), # 验证集测试模型实际训练结果,不需要数据增强 'valid': transforms.Compose([ # 与训练集输入尺寸(裁剪后尺寸)相同 transforms.Resize([64, 64]), transforms.ToTensor(), # 与训练集均值标准差相同 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), } # 由于输入图像尺寸较小,batch_size可以大些。考虑电脑显存问题 batch_size = 128 # 将训练集和验证集文件夹与图像预处理操作对应起来————————image_datasets为字典类型,datasets数据以文件夹形式处理(数据集中的数据按类别划分了单独文件夹,因此以文件夹形式处理) image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']} # print(image_datasets) ''' {'train': Dataset ImageFolder Number of datapoints: 6552 Root location: D:咕泡人工智能-配套资料配套资料4.第四章 深度学习核⼼框架PyTorch第五章:图像识别模型与训练策略(重点)flower_datatrain StandardTransform Transform: Compose( Resize(size=[96, 96], interpolation=bilinear, max_size=None, antialias=True) RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0) CenterCrop(size=(64, 64)) RandomHorizontalFlip(p=0.5) RandomVerticalFlip(p=0.5) ToTensor() Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ), 'valid': Dataset ImageFolder Number of datapoints: 818 Root location: D:咕泡人工智能-配套资料配套资料4.第四章 深度学习核⼼框架PyTorch第五章:图像识别模型与训练策略(重点)flower_datavalid StandardTransform Transform: Compose( Resize(size=[64, 64], interpolation=bilinear, max_size=None, antialias=True) ToTensor() Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )} ''' # 将训练集和验证集文件夹中的数据打乱划分batch——————dataloaders为字典形式 dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']} # print(dataloaders) # {'train': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07130>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07100>} # 计算训练集和验证集分别数据数目 dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} # 获取分类标签 class_names = image_datasets['train'].classes # print(class_names) ''' ['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99'] ''' # 打开cat_to_name.json文件,文件中有数字对应的实际类别名称——数据标签文件 with open('D:/咕泡人工智能-配套资料配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json', 'r') as f: cat_to_name = json.load(f) # print(cat_to_name) ''' {'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers', '7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger', '67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist', '9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy', '15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower', '62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper', '42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 'sunflower', '66': 'osteospermum', '70': 'tree poppy', '85': 'desert-rose', '99': 'bromelia', '87': 'magnolia', '5': 'english marigold', '92': 'bee balm', '28': 'stemless gentian', '97': 'mallow', '57': 'gaura', '40': 'lenten rose', '47': 'marigold', '59': 'orange dahlia', '48': 'buttercup', '55': 'pelargonium', '36': 'ruby-lipped cattleya', '91': 'hippeastrum', '29': 'artichoke', '71': 'gazania', '90': 'canna lily', '18': 'peruvian lily', '98': 'mexican petunia', '8': 'bird of paradise', '30': 'sweet william', '17': 'purple coneflower', '52': 'wild pansy', '84': 'columbine', '12': "colt's foot", '11': 'snapdragon', '96': 'camellia', '23': 'fritillary', '50': 'common dandelion', '44': 'poinsettia', '53': 'primula', '72': 'azalea', '65': 'californian poppy', '80': 'anthurium', '76': 'morning glory', '37': 'cape flower', '56': 'bishop of llandaff', '60': 'pink-yellow dahlia', '82': 'clematis', '58': 'geranium', '75': 'thorn apple', '41': 'barbeton daisy', '95': 'bougainvillea', '43': 'sword lily', '83': 'hibiscus', '78': 'lotus lotus', '88': 'cyclamen', '94': 'foxglove', '81': 'frangipani', '74': 'rose', '89': 'watercress', '73': 'water lily', '46': 'wallflower', '77': 'passion flower', '51': 'petunia'} '''

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100

四、导入Resnet网络——迁移学习

# 迁移学习:使用前人的网络结构和模型做训练 # 数据量较小时:对模型进行微小改动,如冻住某一部分不进行迭代更新训练,只训练更新较少的网络层 # 数据量中等时:对模型进行改动,如冻住少量部分不进行迭代更新训练,其他部分进行迭代更新训练 # 数据量较大时,整个模型不进行冻结,全部更新训练 # 加载models中提供的模型,并且直接用训练好的权重当作初始化参数 #可选网络结构比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],resnet网络效果较好 model_name = 'resnet' # 是否用人家训练好的特征来做 # 此项目冻结输出层前面所有部分,不进行训练更新 feature_extract = True # 是否用GPU训练——————固定写法 train_on_gpu = torch.cuda.is_available() if not train_on_gpu: print('CUDA is not available. Training on CPU ...') else: print('CUDA is available! Training on GPU ...') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 使用18层的resnet网络结构,18层的能快点,条件好点的也可以选152 model_ft = models.resnet18() # resnet18网络结构 # print(model_ft) ''' ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) # resnet18输出层为1000个类别,此次项目预测类别为102,后续修改输出层 (fc): Linear(in_features=512, out_features=1000, bias=True) ) ''' # 自定义函数判断模型参数是否要进行更新 def set_parameter_requires_grad(model, feature_extracting): if feature_extracting: # 对18层resnet网络结构中的权重参数进行遍历 for param in model.parameters(): # requires_grad反向传播时是否更新参数,False不进行更新 param.requires_grad = False # for param in model_ft.parameters(): # param.requires_grad = False # for name, param in model_ft.named_parameters(): # # 输出所有权重参数名称及参数值 # print(name) # print(param) # 输出层权重参数如下: ''' fc.weight Parameter containing: tensor([[ 0.0018, -0.0213, -0.0421, ..., 0.0064, -0.0055, -0.0221], [-0.0298, -0.0036, 0.0277, ..., 0.0181, -0.0311, -0.0340], [-0.0434, -0.0196, -0.0026, ..., 0.0229, 0.0110, 0.0345], ..., [-0.0097, 0.0290, 0.0391, ..., -0.0091, -0.0028, -0.0373], [ 0.0073, 0.0256, 0.0238, ..., 0.0004, -0.0275, -0.0347], [-0.0356, -0.0144, -0.0418, ..., 0.0020, -0.0338, 0.0351]]) fc.bias Parameter containing: tensor([ 0.0226, 0.0271, -0.0001, -0.0043, -0.0339, -0.0216, 0.0056, -0.0067, -0.0024, 0.0062, -0.0425, 0.0148, 0.0340, -0.0365, 0.0084, 0.0407, -0.0335, 0.0133, 0.0155, -0.0303, -0.0304, 0.0291, -0.0440, -0.0278, -0.0324, -0.0041, 0.0068, -0.0349, -0.0380, 0.0169, 0.0203, -0.0439, 0.0185, -0.0206, 0.0164, -0.0020, 0.0067, -0.0101, -0.0166, 0.0155, -0.0302, 0.0095, -0.0119, 0.0124, 0.0065, 0.0435, -0.0298, -0.0022, -0.0379, 0.0406, -0.0065, -0.0210, 0.0151, 0.0111, -0.0251, -0.0374, -0.0246, -0.0052, -0.0139, -0.0246, 0.0008, 0.0318, 0.0358, 0.0370, 0.0277, -0.0124, 0.0391, -0.0158, 0.0052, -0.0437, 0.0076, -0.0352, -0.0045, 0.0270, -0.0332, -0.0177, -0.0420, -0.0116, -0.0224, 0.0410, -0.0063, 0.0408, -0.0062, -0.0034, -0.0084, -0.0306, 0.0077, 0.0284, 0.0437, -0.0336, -0.0037, -0.0261, -0.0234, -0.0166, 0.0153, -0.0213, -0.0396, 0.0089, 0.0307, 0.0027, -0.0179, -0.0098, 0.0343, 0.0375, -0.0432, 0.0126, -0.0060, 0.0370, -0.0030, 0.0219, 0.0140, -0.0165, 0.0409, -0.0134, -0.0352, -0.0265, -0.0312, -0.0205, 0.0268, 0.0429, -0.0260, 0.0252, 0.0016, -0.0357, 0.0052, -0.0040, 0.0173, -0.0442, -0.0089, -0.0116, 0.0166, -0.0381, 0.0143, 0.0032, -0.0010, -0.0092, 0.0130, -0.0224, -0.0258, 0.0187, 0.0179, 0.0061, 0.0297, 0.0404, 0.0218, 0.0116, 0.0345, -0.0209, -0.0256, 0.0091, -0.0302, -0.0306, 0.0239, 0.0412, 0.0408, 0.0253, -0.0016, -0.0182, -0.0015, 0.0042, 0.0150, -0.0024, 0.0138, -0.0164, 0.0076, -0.0181, 0.0339, 0.0179, -0.0289, -0.0327, -0.0399, 0.0419, 0.0386, -0.0345, -0.0321, -0.0413, 0.0390, -0.0339, 0.0359, 0.0012, 0.0298, 0.0134, -0.0128, -0.0251, 0.0435, -0.0231, -0.0432, 0.0385, -0.0423, -0.0137, 0.0170, -0.0412, -0.0413, 0.0114, 0.0428, -0.0425, -0.0089, -0.0290, -0.0112, 0.0277, -0.0109, 0.0083, 0.0324, -0.0163, -0.0389, -0.0206, 0.0052, -0.0091, 0.0151, 0.0244, 0.0010, -0.0173, -0.0015, -0.0437, -0.0377, -0.0107, -0.0185, 0.0059, -0.0198, -0.0395, -0.0129, -0.0363, 0.0408, 0.0418, -0.0230, -0.0122, -0.0168, 0.0281, 0.0338, 0.0321, 0.0219, -0.0041, 0.0307, -0.0244, -0.0007, -0.0307, 0.0387, 0.0051, 0.0417, -0.0241, -0.0165, 0.0186, 0.0210, 0.0373, 0.0192, 0.0415, -0.0230, -0.0091, -0.0324, -0.0416, 0.0304, 0.0065, 0.0398, 0.0036, -0.0232, -0.0392, 0.0109, -0.0108, -0.0320, -0.0032, -0.0138, 0.0357, -0.0247, 0.0363, -0.0185, -0.0197, -0.0068, -0.0120, -0.0377, -0.0101, 0.0210, -0.0243, 0.0269, 0.0128, -0.0142, -0.0385, 0.0185, 0.0233, -0.0051, -0.0172, 0.0224, -0.0434, 0.0078, 0.0159, -0.0201, -0.0363, -0.0246, 0.0044, -0.0306, -0.0377, -0.0313, 0.0366, 0.0368, -0.0303, 0.0066, 0.0322, -0.0143, 0.0343, 0.0233, -0.0337, -0.0211, -0.0060, -0.0167, -0.0189, -0.0236, 0.0292, 0.0194, 0.0372, -0.0055, 0.0430, 0.0243, -0.0126, 0.0208, 0.0273, 0.0145, 0.0269, 0.0020, -0.0070, -0.0102, 0.0016, -0.0191, 0.0397, -0.0001, -0.0044, -0.0360, 0.0095, 0.0357, 0.0089, 0.0235, -0.0244, 0.0088, 0.0222, 0.0259, 0.0096, 0.0189, 0.0390, 0.0401, -0.0208, -0.0358, -0.0197, -0.0248, -0.0088, 0.0085, 0.0018, -0.0132, 0.0289, 0.0287, -0.0400, 0.0063, -0.0054, 0.0399, 0.0123, 0.0102, -0.0326, 0.0194, -0.0049, 0.0104, -0.0171, -0.0080, 0.0429, -0.0056, -0.0298, 0.0064, 0.0341, -0.0191, 0.0132, -0.0174, 0.0435, 0.0035, 0.0093, -0.0321, -0.0366, 0.0307, 0.0088, -0.0395, 0.0357, 0.0032, -0.0149, -0.0247, 0.0124, 0.0436, 0.0126, -0.0156, 0.0050, 0.0109, 0.0183, -0.0404, -0.0018, -0.0104, -0.0395, 0.0390, -0.0306, 0.0261, -0.0244, -0.0253, -0.0329, 0.0334, 0.0429, -0.0138, -0.0190, -0.0235, -0.0204, -0.0393, 0.0217, -0.0332, -0.0438, -0.0294, -0.0158, -0.0103, 0.0238, -0.0419, -0.0408, 0.0113, 0.0366, 0.0221, -0.0190, -0.0244, 0.0335, 0.0102, -0.0101, -0.0111, -0.0284, -0.0155, -0.0114, 0.0137, 0.0019, -0.0006, -0.0074, 0.0137, -0.0260, -0.0037, 0.0199, 0.0155, -0.0296, 0.0173, 0.0224, 0.0091, -0.0167, 0.0004, 0.0206, -0.0237, -0.0195, 0.0387, -0.0045, 0.0088, 0.0261, -0.0418, -0.0144, -0.0375, -0.0106, -0.0354, 0.0411, -0.0053, -0.0248, -0.0010, 0.0323, -0.0203, 0.0012, 0.0204, -0.0320, 0.0321, 0.0137, 0.0064, -0.0329, 0.0051, -0.0340, 0.0171, 0.0422, 0.0266, 0.0238, -0.0164, 0.0103, -0.0413, -0.0355, 0.0127, 0.0207, -0.0240, 0.0398, 0.0323, 0.0217, 0.0030, 0.0396, 0.0327, -0.0060, 0.0312, -0.0117, -0.0079, 0.0095, 0.0423, 0.0010, 0.0018, 0.0233, 0.0434, -0.0210, 0.0049, -0.0072, 0.0031, 0.0052, -0.0292, -0.0217, -0.0253, 0.0274, -0.0230, -0.0342, -0.0149, 0.0137, -0.0057, 0.0344, -0.0327, 0.0370, 0.0142, -0.0194, 0.0109, 0.0366, -0.0046, -0.0203, 0.0088, -0.0117, 0.0263, 0.0020, -0.0335, 0.0387, -0.0196, 0.0386, -0.0433, -0.0012, 0.0138, -0.0383, 0.0059, -0.0313, -0.0404, -0.0103, -0.0105, -0.0196, -0.0297, 0.0331, 0.0372, -0.0275, 0.0007, 0.0266, 0.0006, 0.0255, -0.0035, -0.0365, 0.0165, -0.0423, 0.0054, 0.0041, 0.0026, -0.0338, 0.0436, -0.0385, 0.0346, -0.0295, -0.0315, 0.0096, 0.0376, 0.0166, -0.0351, 0.0077, 0.0028, 0.0139, 0.0394, 0.0115, -0.0231, -0.0134, 0.0167, 0.0160, -0.0003, -0.0152, 0.0399, 0.0367, -0.0424, 0.0104, -0.0025, -0.0118, 0.0020, -0.0115, -0.0253, -0.0290, 0.0042, -0.0025, -0.0155, 0.0101, -0.0170, 0.0135, 0.0283, 0.0441, 0.0294, -0.0083, -0.0428, -0.0267, -0.0247, -0.0344, 0.0177, 0.0173, 0.0029, -0.0369, -0.0150, -0.0193, 0.0428, -0.0401, 0.0377, 0.0138, -0.0136, -0.0104, 0.0325, 0.0335, -0.0152, -0.0014, -0.0287, 0.0375, -0.0426, 0.0393, 0.0016, -0.0244, 0.0394, 0.0212, -0.0019, -0.0024, -0.0212, 0.0249, -0.0244, -0.0144, 0.0227, 0.0073, 0.0412, -0.0194, -0.0300, 0.0084, -0.0273, 0.0157, -0.0270, -0.0411, 0.0153, -0.0040, -0.0268, 0.0048, 0.0164, -0.0165, -0.0089, 0.0328, 0.0345, -0.0013, -0.0226, -0.0097, -0.0234, 0.0156, 0.0236, 0.0076, -0.0349, -0.0442, -0.0293, 0.0132, -0.0420, -0.0020, 0.0059, -0.0231, -0.0187, -0.0304, 0.0270, 0.0382, 0.0407, -0.0244, 0.0394, -0.0204, -0.0356, -0.0287, -0.0221, -0.0330, 0.0249, 0.0247, 0.0250, -0.0180, -0.0182, -0.0358, -0.0196, -0.0441, 0.0407, -0.0027, 0.0177, 0.0167, -0.0258, -0.0097, -0.0210, -0.0106, -0.0331, 0.0008, -0.0392, -0.0378, 0.0418, 0.0093, -0.0442, 0.0419, -0.0233, 0.0076, 0.0394, 0.0428, 0.0027, -0.0117, -0.0161, 0.0153, 0.0035, 0.0222, 0.0375, -0.0300, 0.0200, 0.0285, -0.0120, 0.0074, 0.0054, -0.0117, -0.0161, -0.0088, 0.0428, 0.0103, 0.0136, 0.0376, -0.0172, 0.0086, 0.0183, 0.0381, 0.0109, -0.0177, -0.0416, -0.0351, 0.0338, -0.0404, 0.0324, 0.0041, -0.0172, 0.0002, -0.0415, 0.0423, 0.0172, -0.0068, 0.0370, 0.0321, 0.0040, 0.0401, 0.0030, 0.0238, 0.0102, -0.0437, 0.0204, 0.0432, 0.0399, -0.0186, -0.0067, 0.0205, -0.0098, 0.0009, 0.0019, 0.0317, -0.0237, 0.0062, 0.0097, -0.0312, -0.0033, 0.0028, -0.0021, 0.0006, -0.0320, -0.0196, 0.0363, 0.0285, 0.0321, -0.0132, 0.0344, -0.0232, 0.0379, -0.0166, 0.0311, -0.0174, 0.0431, 0.0006, -0.0066, -0.0344, -0.0076, 0.0245, 0.0286, -0.0388, 0.0114, 0.0204, 0.0137, 0.0387, 0.0206, -0.0392, -0.0109, 0.0375, 0.0269, 0.0232, -0.0362, 0.0235, -0.0137, 0.0303, 0.0389, -0.0068, 0.0306, 0.0273, 0.0264, -0.0074, 0.0315, -0.0291, -0.0027, -0.0061, 0.0188, -0.0123, -0.0360, -0.0266, 0.0292, 0.0248, 0.0127, -0.0251, -0.0426, 0.0066, -0.0005, -0.0162, -0.0236, -0.0330, 0.0339, 0.0319, 0.0135, 0.0260, -0.0389, -0.0375, -0.0192, -0.0079, -0.0066, -0.0261, -0.0441, -0.0042, 0.0086, 0.0291, 0.0283, 0.0028, -0.0137, -0.0218, 0.0109, -0.0052, -0.0213, -0.0108, -0.0354, -0.0012, 0.0006, -0.0111, -0.0066, 0.0401, -0.0313, 0.0203, 0.0060, 0.0339, 0.0072, 0.0148, 0.0047, 0.0363, -0.0117, 0.0338, 0.0238, -0.0367, 0.0093, 0.0048, 0.0222, 0.0164, 0.0399, 0.0005, 0.0056, 0.0023, -0.0330, 0.0241, -0.0403, 0.0047, 0.0312, -0.0148, 0.0146, 0.0268, 0.0439, -0.0265, -0.0044, 0.0441, 0.0395, -0.0164, 0.0117, 0.0343, 0.0355, 0.0243, -0.0091, 0.0083, -0.0213, -0.0196, -0.0082, 0.0070, 0.0403, -0.0407, 0.0270, 0.0375, 0.0207, 0.0207, -0.0049, 0.0020, 0.0429, -0.0432, -0.0368, -0.0040, 0.0035, 0.0114, -0.0345, -0.0286, 0.0232, 0.0342, 0.0437, 0.0193, 0.0030, 0.0375, 0.0161, 0.0172, -0.0069, -0.0118, -0.0235, -0.0155, -0.0029, -0.0035, 0.0372, -0.0343, 0.0183, -0.0057, -0.0093, -0.0322, -0.0303, 0.0275, -0.0364, -0.0240, -0.0090, -0.0058, -0.0055, 0.0315, -0.0020, 0.0268, -0.0305, -0.0286, -0.0083, 0.0015, -0.0226, 0.0249, -0.0133, -0.0359, 0.0393, 0.0058, -0.0354, 0.0011, 0.0424, 0.0363, 0.0405, 0.0006, -0.0422, 0.0363, -0.0298, -0.0319, -0.0131, 0.0021, 0.0276, -0.0302, -0.0350, -0.0433, 0.0185, 0.0263, 0.0307, -0.0093, 0.0377, 0.0031, -0.0115, -0.0297, -0.0327, 0.0103, 0.0179, -0.0071, 0.0029, -0.0345, -0.0335, -0.0184, 0.0426, 0.0169, 0.0039, -0.0071, 0.0421, -0.0185, -0.0235, -0.0288, -0.0305, -0.0199, 0.0091, -0.0162, -0.0269, 0.0172, 0.0330, -0.0416, -0.0347, -0.0392, -0.0148, -0.0167]) ''' # 输出可以进行梯度更新的权重参数名称——————————因为前面修改了所有的参数更新为False,此判断无输出 # if param.requires_grad == True: # print('t', name)

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269

五、模型训练及参数配置

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True): # 加载预训练模型 # resnet有18、50、101、152等层 # model_name.resnet18只有网络结构,没有预训练参数。pretrained是否加载预训练模型,pretrained=True下载resnet18到C:/user/cache model_ft = model_name.resnet18(pretrained=use_pretrained) # 冻结网络结构中所有的参数更新 set_parameter_requires_grad(model_ft, feature_extract) # 找到全连接层的输入参数:512 # resnet18的全连接层:(fc): Linear(in_features=512, out_features=1000, bias=True) num_ftrs = model_ft.fc.in_features # num_classes自己的任务类别数,覆盖resnet18中的全连接输出,此全连接网络中的参数可以进行梯度更新 model_ft.fc = nn.Linear(num_ftrs, num_classes) # # 输入大小根据自己配置来 # input_size = 64 # return model_ft, input_size return model_ft # input_size好像没用处,删去 # model_ft, input_size = initialize_model(models, 102, feature_extract, use_pretrained=True) model_ft = initialize_model(models, 102, feature_extract, use_pretrained=True) #GPU还是CPU计算 model_ft = model_ft.to(device) # 模型保存路径,名字自己起——————网络结构和权重参数 filename = 'D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/best_my.pt' # 是否训练所有层 # 将model_ft的所有权重参数保存到params_to_update params_to_update = model_ft.parameters() print("Params to learn:") if feature_extract: params_to_update = [] for name, param in model_ft.named_parameters(): # 此项目中只有新加的全连接层param.requires_grad为True,进行参数更新 if param.requires_grad == True: params_to_update.append(param) print("t", name) else: for name, param in model_ft.named_parameters(): if param.requires_grad == True: print("t", name) # 优化器设置,只训练params_to_update中的参数 optimizer_ft = optim.Adam(params_to_update, lr=1e-2) # 学习率衰减策略:学习率每step_size个epoch衰减成原来的gamma倍 scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1) # 损失函数 criterion = nn.CrossEntropyLoss() # 训练函数 def train_model(model, dataloaders, criterion, optimizer, filename, num_epochs=25): # 计算训练起始时间 since = time.time() # 记录训练最好的那一次的准确率 best_acc = 0 # 判断模型放到CPU或者GPU model.to(device) # 训练过程中打印一堆损失和指标 # 验证准确率 val_acc_history = [] # 训练准确率 train_acc_history = [] # 训练损失 train_losses = [] # 验证损失 valid_losses = [] # 当前学习率 optimizer.param_groups是一个字典结构 LRs = [optimizer.param_groups[0]['lr']] # print(optimizer.param_groups) ''' [{'params': [Parameter containing: tensor([[-0.0065, -0.0276, 0.0032, ..., -0.0383, -0.0091, 0.0323], [ 0.0141, 0.0429, -0.0220, ..., 0.0234, 0.0301, -0.0281], [ 0.0064, -0.0427, 0.0039, ..., 0.0214, -0.0171, -0.0016], ..., [-0.0355, 0.0126, -0.0099, ..., -0.0322, -0.0201, 0.0245], [-0.0127, 0.0114, -0.0213, ..., 0.0270, -0.0070, -0.0315], [-0.0226, -0.0235, 0.0262, ..., -0.0109, 0.0241, 0.0084]], requires_grad=True), Parameter containing: tensor([-0.0118, 0.0029, -0.0184, 0.0226, 0.0082, -0.0320, -0.0046, 0.0358, -0.0234, 0.0430, 0.0245, 0.0431, -0.0127, -0.0231, 0.0230, 0.0357, -0.0181, 0.0389, 0.0127, 0.0343, 0.0044, 0.0217, -0.0323, -0.0211, 0.0309, 0.0416, -0.0317, -0.0248, 0.0093, -0.0324, -0.0115, 0.0181, -0.0190, -0.0005, 0.0418, -0.0369, -0.0144, -0.0229, -0.0295, -0.0048, 0.0088, -0.0371, -0.0203, -0.0163, 0.0073, 0.0044, -0.0410, -0.0289, -0.0305, -0.0363, 0.0409, 0.0364, 0.0082, 0.0419, -0.0063, -0.0100, 0.0008, -0.0270, -0.0163, 0.0059, -0.0100, 0.0252, 0.0183, -0.0160, 0.0027, 0.0347, -0.0131, -0.0292, -0.0225, -0.0183, 0.0326, -0.0062, -0.0422, -0.0220, -0.0410, -0.0408, 0.0405, -0.0046, -0.0339, 0.0411, -0.0015, -0.0371, -0.0152, 0.0244, -0.0128, -0.0117, -0.0275, 0.0333, 0.0033, 0.0276, 0.0302, -0.0367, 0.0236, 0.0409, 0.0192, -0.0411, -0.0224, -0.0200, -0.0321, 0.0120, -0.0427, 0.0333], requires_grad=True)], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.01}] ''' # 查看优化器参数 # print(optimizer) ''' Adam ( Parameter Group 0 amsgrad: False betas: (0.9, 0.999) capturable: False differentiable: False eps: 1e-08 foreach: None fused: None initial_lr: 0.01 lr: 0.01 maximize: False weight_decay: 0 ) ''' # 最好的那次模型,后续会变的,先初始化————————复制当前的权重参数,model.state_dict()模型当前权重参数 best_model_wts = copy.deepcopy(model.state_dict()) # 一个个epoch来遍历 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # 训练和验证 for phase in ['train', 'valid']: if phase == 'train': model.train() # 训练 else: model.eval() # 验证 # 初始化损失和预测正确个数 running_loss = 0.0 running_corrects = 0 # 把数据都取个遍——————dataloaders字典结构,根据phase关键字决定取哪一部分 for inputs, labels in dataloaders[phase]: # to(device)数据放到你的CPU或GPU inputs = inputs.to(device) labels = labels.to(device) # 梯度清零 optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) # 102个类别中概率最大值的下标 _, preds = torch.max(outputs, 1) # 训练阶段更新权重 if phase == 'train': loss.backward() optimizer.step() #完成一次迭代 # 累加计算总损失和总准确率 # input格式为(batch, c, h, w),inputs.size(0)表示batch那个维度 running_loss += loss.item() * inputs.size(0) # 预测结果最大的和真实值是否一致 running_corrects += torch.sum(preds == labels.data) # 计算每个epoch的损失和准确率 epoch_loss = running_loss / len(dataloaders[phase].dataset) # 算平均 epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) # 一个epoch需要多少时间 time_elapsed = time.time() - since print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) # 得到最好那次的模型 if phase == 'valid' and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) state = { 'state_dict': model.state_dict(), # 字典里key就是各层的名字,值就是训练好的权重 'best_acc': best_acc, 'optimizer': optimizer.state_dict(), } torch.save(state, filename) if phase == 'valid': val_acc_history.append(epoch_acc) valid_losses.append(epoch_loss) # scheduler.step(epoch_loss)#学习率衰减 if phase == 'train': train_acc_history.append(epoch_acc) train_losses.append(epoch_loss) print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr'])) LRs.append(optimizer.param_groups[0]['lr']) print() scheduler.step() # 学习率衰减 # 总体运行花了多少时间 time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best val Acc: {:4f}'.format(best_acc)) # 训练完后用最好的一次当做模型最终的结果,等着一会测试 model.load_state_dict(best_model_wts) return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs # 训练模型 model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, filename, num_epochs=20)

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206

训练过程展示:
训练起始结果
训练结束结果

六、验证最佳训练模型

# 加载最佳训练模型 checkpoint = torch.load(filename) best_acc = checkpoint['best_acc'] model_ft.load_state_dict(checkpoint['state_dict']) # 随机得到一个batch的验证数据进行测试 dataiter = iter(dataloaders['valid']) images, labels = next(dataiter) model_ft.eval() if train_on_gpu: output = model_ft(images.cuda()) else: output = model_ft(images) # 得到概率最大的一个 _, preds_tensor = torch.max(output, 1) # 判断数据是否在GPU,在的话数据转cpu中转ndarray类型,在的话直接转ndarray类型 preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy()) # 展示预测结果 def im_convert(tensor): """ 展示数据""" # 数据从cpu中克隆一份出来 image = tensor.to("cpu").clone().detach() # 数据从tensor格式转为ndarray # numpy.squeeze(a, axis=None),用于从数组的形状中删除单维条目,其中a表示输入的数组,axis用于指定需要删除的维度。如果axis为空,则删除所有单维度的条目。 image = image.numpy().squeeze() # PIL工具包 # tensor中数据格式为c*h*w,正常数据格式为h*w*c,transpose()将0、1、2代表tensor中三个维度c、h、w,转换为1、2、0即h、w、c格式 image = image.transpose(1, 2, 0) # 反标准化操作,均值u,标准差b,x = x*b+u image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)) print(image) # clip()将数组中的元素值限制在给定的最小值和最大值之间,超出这个范围的值会被截断到最小值或最大值。 image = image.clip(0, 1) print(image) return image fig = plt.figure(figsize=(20, 20)) columns = 4 rows = 2 for idx in range(columns*rows): ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[]) plt.imshow(im_convert(images[idx])) ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red")) plt.show()

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950

验证结果展示:
随机一个batch的八张验证结果
随机一个batch的八张验证结果

相关知识

使用PyTorch实现对花朵的分类
搭建简单的神经网络——使用pytorch实现鸢尾花的分类
【大虾送书第二期】《Python机器学习:基于PyTorch和Scikit
pytorch 花朵的分类识别
创建虚拟环境并,创建pytorch 1.3.1
Pytorch入门——手把手教你MNIST手写数字识别
pytorch实现迁移训练
小白学摄影 让你镜头下的花儿不再普通
Python基于Pytorch Transformer实现对iris鸢尾花的分类预测,分别使用CPU和GPU训练
Pytorch介绍与linux、windows环境下安装

网址: 小白学Pytorch使用(4 https://www.huajiangbk.com/newsview1057020.html

所属分类:花卉
上一篇: 紫鸟浏览器如何转移账号
下一篇: 细粒度分类 CUB

推荐分享