首页 分享 Pytorch: 在预训练模型中输入的数据预处理

Pytorch: 在预训练模型中输入的数据预处理

来源:花匠小妙招 时间:2025-05-13 20:22

我们经常看到:

transform = transforms.Compose([

transforms.RandomResizedCrop(100),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

])

这里使用的mean= [ 0.485,0.456,0.406 ]和STD=[ 0.229,0.224,0.225 ]进行归一化.

这里的数字是在预训练模型训练的时候采用的格式,我们需要和预训练模型保持相同的格式。

同时在pytorch官方介绍中说明了预训练模型都是采用【0,1】标准分布的图像训练。

以下是pytorch的样例所使用的mean和std

https://github.com/pytorch/examples/tree/42e5b996718797e45c46a25c55b031e6768f8440

附上计算mean和std的代码:

transform = transforms.Compose([

transforms.ToPILImage(),

transforms.ToTensor()

])

dataloader = torch.utils.data.DataLoader(*torch_dataset*, batch_size=4096, shuffle=False, num_workers=4)

pop_mean = []

pop_std0 = []

pop_std1 = []

for i, data in enumerate(dataloader, 0):

numpy_image = data['image'].numpy()

batch_mean = np.mean(numpy_image, axis=(0,2,3))

batch_std0 = np.std(numpy_image, axis=(0,2,3))

batch_std1 = np.std(numpy_image, axis=(0,2,3), ddof=1)

pop_mean.append(batch_mean)

pop_std0.append(batch_std0)

pop_std1.append(batch_std1)

pop_mean = np.array(pop_mean).mean(axis=0)

pop_std0 = np.array(pop_std0).mean(axis=0)

pop_std1 = np.array(pop_std1).mean(axis=0)

计算出来只需要添加即可:

transform = transforms.Compose([

transforms.ToPILImage(),

transforms.ToTensor(),

transforms.Normalize(mean=*your_calculated_mean*, std=*your_calculated_std*)

])

分batch计算mean,再取平均。

相关知识

Pytorch: 在预训练模型中输入的数据预处理
pytorch学习之加载预训练模型
pytorch实现迁移训练
pytorch环境下kaggle数据集花种类识别
PyTorch环境下的柑橘病变图像识别与数据集处理
使用pytorch中预训练模型VGG19获取图像特征,得到图像embedding
ResNet残差网络在PyTorch中的实现——训练花卉分类器
【基于PyTorch实现经典网络架构的花卉图像分类模型】
pytorch——AlexNet——训练花分类数据集
玫瑰花图片数据集助力深度学习模型训练

网址: Pytorch: 在预训练模型中输入的数据预处理 https://www.huajiangbk.com/newsview1947669.html

所属分类:花卉
上一篇: 常见水培植物有哪些
下一篇: 垂盆草的扦插方法(轻松掌握垂盆草

推荐分享