Pytorch项目实战-2:花卉分类

一、前言

在深度学习项目中,数据集的处理和模型的训练、测试、预测是关键环节。本文将为小白详细介绍从数据集搜集、清洗、划分到模型训练、测试、预测以及模型结构查看的全流程,附带代码和操作说明,让你轻松上手!

二、数据集

二、数据集获取

2.1 自建数据集 vs 公开数据集

  • 自建数据集:适合本科毕设、大作业等小规模场景,可通过自己拍摄爬虫爬取(如百度图片)构建。
  • 公开数据集:适合专业研究,例如医学图像分割可从ISIC Archive获取。

2.2 百度图片爬虫实战(附代码)

代码文件:data_get.py

# -*- coding: utf-8 -*-
import requests
import re
import osheaders = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/84.0.4147.125 Safari/537.36'}
name = input('请输入要爬取的图片类别:')
num = 0
num_1 = 0
num_2 = 0
x = input('请输入要爬取的图片数量?(1=60张,2=120张):')
list_1 = []for i in range(int(x)):name_1 = os.getcwd()name_2 = os.path.join(name_1, 'data/' + name)url = f'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word={name}&pn={i*30}'res = requests.get(url, headers=headers)htlm_1 = res.content.decode()a = re.findall('"objURL":"(.*?)",', htlm_1)if not os.path.exists(name_2):os.makedirs(name_2)for b in a:try:b_2 = re.findall('https:(.*?)&', b)[0]  # 提取图片URLif b_2 not in list_1:num += 1img = requests.get(b)save_path = os.path.join(name_1, 'data/' + name, f'{name}{num}.jpg')with open(save_path, 'ab') as f:f.write(img.content)print(f'---------正在下载第{num}张图片----------')list_1.append(b_2)else:num_1 += 1  # 统计重复图片except:print(f'---------第{num}张图片无法下载----------')num_2 += 1  # 统计失败图片print(f'下载完成!总共下载{num+num_1+num_2}张,成功{num}张,重复{num_1}张,失败{num_2}张')

使用步骤

  1. 保存代码为data_get.py
  2. 运行后输入图片类别(如 “向日葵”)和数量(1 或 2)
  3. 图片会自动保存到data/类别名目录下

 三、数据集清洗(解决中文路径和坏图问题)

3.1 为什么需要清洗?

  • OpenCV 对中文路径支持差,会导致读取错误
  • 爬取的图片可能包含损坏文件(无法读取的坏图)

3.2 清洗代码(data_clean.py)

import shutil
import cv2
import os
import numpy as np
from tqdm import tqdmdef cv_imread(file_path, type=-1):"""支持中文路径读取图片"""img = cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), -1)return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if type==0 else imgdef cv_imwrite(file_path, cv_img, is_gray=True):"""支持中文路径保存图片"""if len(cv_img.shape)==3 and is_gray:cv_img = cv_img[:, :, 0]cv2.imencode(os.path.splitext(file_path)[1], cv_img)[1].tofile(file_path)def data_clean(src_folder, english_name):clean_folder = f{src_folder}_cleanedif os.path.isdir(clean_folder):shutil.rmtree(clean_folder)  # 删除已存在目录os.mkdir(clean_folder)image_names = os.listdir(src_folder)with tqdm(total=len(image_names)) as pabr:for i, name in enumerate(image_names):path = os.path.join(src_folder, name)try:img = cv_imread(path)# 保存为英文名称的JPG图片save_name = f{english_name}_{i}.jpgsave_path = os.path.join(clean_folder, save_name)cv_imwrite(save_path, img, is_gray=False)except:print(f{name}是坏图)pabr.update(1)if __name__ == __main__:data_clean(src_folder=D:/数据集/向日葵, english_name=sunflowers)  # 替换为你的路径

运行结果

  • 生成原目录_cleaned文件夹,存放清洗后的图片
  • 自动跳过坏图,重命名为英文(如sunflowers_0.jpg

四、数据集划分(6:2:2 比例)

4.1 适用场景

  • 当数据集未区分训练集、验证集、测试集时使用
  • 要求:图片按类别存放在子目录下(如data/向日葵data/玫瑰

4.2 划分代码(data_split.py)

import os
import shutil
import random
from tqdm import tqdmdef split_data(src_dir, save_dir, ratios=[0.6, 0.2, 0.2]):os.makedirs(save_dir, exist_ok=True)categories = os.listdir(src_dir)for cate in categories:cate_path = os.path.join(src_dir, cate)imgs = os.listdir(cate_path)random.shuffle(imgs)total = len(imgs)# 计算划分索引train_idx = int(total * ratios[0])val_idx = train_idx + int(total * ratios[1])# 划分数据集for phase, start, end in zip(['train', 'val', 'test'], [0, train_idx, val_idx]):phase_dir = os.path.join(save_dir, phase, cate)os.makedirs(phase_dir, exist_ok=True)for img in tqdm(imgs[start:end], desc=fProcessing {phase} {cate}):src_img = os.path.join(cate_path, img)dest_img = os.path.join(phase_dir, img)shutil.copyfile(src_img, dest_img)if __name__ == __main__:src_dir = D:/数据集_cleaned  # 清洗后的数据集路径save_dir = D:/数据集_split  # 划分结果保存路径split_data(src_dir, save_dir)

关键操作

  1. 修改src_dir为清洗后的数据集路径
  2. 运行后生成save_dir/split目录,包含train/val/test子目录
  3. 比例可在ratios参数中调整(总和需为 1)

五、模型训练(以 ResNet50 为例)

5.1 准备工作

一开始执行之前会有一个会需要下载预训练模型到指定目录,由于众所周知的原因,大家需要提前先把模型下载下来放置到这个目录,这个大家自行探索。

image20221201114528829

右键直接运行train.py就可以开始训练模型,代码首先会输出模型的基本信息(模型有几个卷积层、池化层、全连接层构成)和运行的记录。

  • 下载预训练模型(如 ResNet50),放入指定目录(代码中标记TODO
  • 确保数据集划分正确(训练集路径需对应)

5.2 开始训练

from torchutils import *
from torchvision import datasets, models, transforms
import os.path as osp
import os
if torch.cuda.is_available():device = torch.device('cuda:0')
else:device = torch.device('cpu')
print(f'Using device: {device}')
# 固定随机种子,保证实验结果是可以复现的
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
data_path = r"G:\code\2023_pytorch110_classification_42-master\flowers_5_split" # todo 数据集路径
# 注: 执行之前请先划分数据集
# 超参数设置
params = {# 'model': 'vit_tiny_patch16_224',  # 选择预训练模型# 'model': 'resnet50d',  # 选择预训练模型'model': 'efficientnet_b3a',  # 选择预训练模型"img_size": 224,  # 图片输入大小"train_dir": osp.join(data_path, "train"),  # todo 训练集路径"val_dir": osp.join(data_path, "val"),  # todo 验证集路径'device': device,  # 设备'lr': 1e-3,  # 学习率'batch_size': 4,  # 批次大小'num_workers': 0,  # 进程'epochs': 10,  # 轮数"save_dir": "../checkpoints/",  # todo 保存路径"pretrained": True,"num_classes": len(os.listdir(osp.join(data_path, "train"))),  # 类别数目, 自适应获取类别数目'weight_decay': 1e-5  # 学习率衰减
}# 定义模型
class SELFMODEL(nn.Module):def __init__(self, model_name=params['model'], out_features=params['num_classes'],pretrained=True):super().__init__()self.model = timm.create_model(model_name, pretrained=pretrained)  # 从预训练的库中加载模型# self.model = timm.create_model(model_name, pretrained=pretrained, checkpoint_path="pretrained/resnet50d_ra2-464e36ba.pth")  # 从预训练的库中加载模型# classifierif model_name[:3] == "res":n_features = self.model.fc.in_features  # 修改全连接层数目self.model.fc = nn.Linear(n_features, out_features)  # 修改为本任务对应的类别数目elif model_name[:3] == "vit":n_features = self.model.head.in_features  # 修改全连接层数目self.model.head = nn.Linear(n_features, out_features)  # 修改为本任务对应的类别数目else:n_features = self.model.classifier.in_featuresself.model.classifier = nn.Linear(n_features, out_features)# resnet修改最后的全链接层print(self.model)  # 返回模型def forward(self, x):  # 前向传播x = self.model(x)return x# 定义训练流程
def train(train_loader, model, criterion, optimizer, epoch, params):metric_monitor = MetricMonitor()  # 设置指标监视器model.train()  # 模型设置为训练模型nBatch = len(train_loader)stream = tqdm(train_loader)for i, (images, target) in enumerate(stream, start=1):  # 开始训练images = images.to(params['device'], non_blocking=True)  # 加载数据target = target.to(params['device'], non_blocking=True)  # 加载模型output = model(images)  # 数据送入模型进行前向传播loss = criterion(output, target.long())  # 计算损失f1_macro = calculate_f1_macro(output, target)  # 计算f1分数recall_macro = calculate_recall_macro(output, target)  # 计算recall分数acc = accuracy(output, target)  # 计算准确率分数metric_monitor.update('Loss', loss.item())  # 更新损失metric_monitor.update('F1', f1_macro)  # 更新f1metric_monitor.update('Recall', recall_macro)  # 更新recallmetric_monitor.update('Accuracy', acc)  # 更新准确率optimizer.zero_grad()  # 清空学习率loss.backward()  # 损失反向传播optimizer.step()  # 更新优化器lr = adjust_learning_rate(optimizer, epoch, params, i, nBatch)  # 调整学习率stream.set_description(  # 更新进度条"Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch,metric_monitor=metric_monitor))return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['Loss']["avg"]  # 返回结果# 定义验证流程
def validate(val_loader, model, criterion, epoch, params):metric_monitor = MetricMonitor()  # 验证流程model.eval()  # 模型设置为验证格式stream = tqdm(val_loader)  # 设置进度条with torch.no_grad():  # 开始推理for i, (images, target) in enumerate(stream, start=1):images = images.to(params['device'], non_blocking=True)  # 读取图片target = target.to(params['device'], non_blocking=True)  # 读取标签output = model(images)  # 前向传播loss = criterion(output, target.long())  # 计算损失f1_macro = calculate_f1_macro(output, target)  # 计算f1分数recall_macro = calculate_recall_macro(output, target)  # 计算recall分数acc = accuracy(output, target)  # 计算accmetric_monitor.update('Loss', loss.item())  # 后面基本都是更新进度条的操作metric_monitor.update('F1', f1_macro)metric_monitor.update("Recall", recall_macro)metric_monitor.update('Accuracy', acc)stream.set_description("Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch,metric_monitor=metric_monitor))return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['Loss']["avg"]# 展示训练过程的曲线
def show_loss_acc(acc, loss, val_acc, val_loss, sava_dir):# 从history中提取模型训练集和验证集准确率信息和误差信息# 按照上下结构将图画输出plt.figure(figsize=(8, 8))plt.subplot(2, 1, 1)plt.plot(acc, label='Training Accuracy')plt.plot(val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.ylabel('Accuracy')plt.ylim([min(plt.ylim()), 1])plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)plt.plot(loss, label='Training Loss')plt.plot(val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.ylabel('Cross Entropy')plt.title('Training and Validation Loss')plt.xlabel('epoch')# 保存在savedir目录下。save_path = osp.join(save_dir, "results.png")plt.savefig(save_path, dpi=100)if __name__ == '__main__':accs = []losss = []val_accs = []val_losss = []data_transforms = get_torch_transforms(img_size=params["img_size"])  # 获取图像预处理方式train_transforms = data_transforms['train']  # 训练集数据处理方式valid_transforms = data_transforms['val']  # 验证集数据集处理方式train_dataset = datasets.ImageFolder(params["train_dir"], train_transforms)  # 加载训练集valid_dataset = datasets.ImageFolder(params["val_dir"], valid_transforms)  # 加载验证集if params['pretrained'] == True:save_dir = osp.join(params['save_dir'], params['model']+"_pretrained_" + str(params["img_size"]))  # 设置模型保存路径else:save_dir = osp.join(params['save_dir'], params['model'] + "_nopretrained_" + str(params["img_size"]))  # 设置模型保存路径if not osp.isdir(save_dir):  # 如果保存路径不存在的话就创建os.makedirs(save_dir)  #print("save dir {} created".format(save_dir))train_loader = DataLoader(  # 按照批次加载训练集train_dataset, batch_size=params['batch_size'], shuffle=True,num_workers=params['num_workers'], pin_memory=True,)val_loader = DataLoader(  # 按照批次加载验证集valid_dataset, batch_size=params['batch_size'], shuffle=False,num_workers=params['num_workers'], pin_memory=True,)print(train_dataset.classes)model = SELFMODEL(model_name=params['model'], out_features=params['num_classes'],pretrained=params['pretrained']) # 加载模型# model = nn.DataParallel(model)  # 模型并行化,提高模型的速度# resnet50d_1epochs_accuracy0.50424_weights.pthmodel = model.to(params['device'])  # 模型部署到设备上criterion = nn.CrossEntropyLoss().to(params['device'])  # 设置损失函数optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])  # 设置优化器# 损失函数和优化器可以自行设置修改。# criterion = nn.CrossEntropyLoss().to(params['device'])  # 设置损失函数# optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])  # 设置优化器best_acc = 0.0  # 记录最好的准确率# 只保存最好的那个模型。for epoch in range(1, params['epochs'] + 1):  # 开始训练acc, loss = train(train_loader, model, criterion, optimizer, epoch, params)val_acc, val_loss = validate(val_loader, model, criterion, epoch, params)accs.append(acc)losss.append(loss)val_accs.append(val_acc)val_losss.append(val_loss)if val_acc >= best_acc:# 保存的时候设置一个保存的间隔,或者就按照目前的情况,如果前面的比后面的效果好,就保存一下。# 按照间隔保存的话得不到最好的模型。save_path = osp.join(save_dir, f"{params['model']}_{epoch}epochs_accuracy{acc:.5f}_weights.pth")torch.save(model.state_dict(), save_path)best_acc = val_accshow_loss_acc(accs, losss, val_accs, val_losss, save_dir)print("训练已完成,模型和训练日志保存在: {}".format(save_dir))

运行结果:

  • 输出模型结构(卷积层 / 池化层 / 全连接层)
  • 保存训练曲线(acc.pngloss.png
  • 自动保存最优模型到指定目录

六、模型测试与预测

6.1 测试代码(test.py)

python

from torchutils import *
from torchvision import datasets, models, transforms
import os.path as osp
import os
from train import SELFMODELif torch.cuda.is_available():device = torch.device('cuda:0')
else:device = torch.device('cpu')
print(f'Using device: {device}')
# 固定随机种子,保证实验结果是可以复现的
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Truedata_path = "../flowers_data_split"  # todo 修改为数据集根目录
model_path = "../checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99501_weights.pth"  # todo 模型地址
model_name = 'resnet50d'  # todo 模型名称
img_size = 224  # todo 数据集训练时输入模型的大小
# 注: 执行之前请先划分数据集
# 超参数设置
params = {# 'model': 'vit_tiny_patch16_224',  # 选择预训练模型# 'model': 'efficientnet_b3a',  # 选择预训练模型'model': model_name,  # 选择预训练模型"img_size": img_size,  # 图片输入大小"test_dir": osp.join(data_path, "test"),  # todo 测试集子目录'device': device,  # 设备'batch_size': 4,  # 批次大小'num_workers': 0,  # 进程"num_classes": len(os.listdir(osp.join(data_path, "train"))),  # 类别数目, 自适应获取类别数目
}def test(val_loader, model, params, class_names):metric_monitor = MetricMonitor()  # 验证流程model.eval()  # 模型设置为验证格式stream = tqdm(val_loader)  # 设置进度条# 对模型分开进行推理test_real_labels = []test_pre_labels = []with torch.no_grad():  # 开始推理for i, (images, target) in enumerate(stream, start=1):images = images.to(params['device'], non_blocking=True)  # 读取图片target = target.to(params['device'], non_blocking=True)  # 读取标签output = model(images)  # 前向传播# loss = criterion(output, target.long())  # 计算损失# print(output)target_numpy = target.cpu().numpy()y_pred = torch.softmax(output, dim=1)y_pred = torch.argmax(y_pred, dim=1).cpu().numpy()test_real_labels.extend(target_numpy)test_pre_labels.extend(y_pred)# print(target_numpy)# print(y_pred)f1_macro = calculate_f1_macro(output, target)  # 计算f1分数recall_macro = calculate_recall_macro(output, target)  # 计算recall分数acc = accuracy(output, target)  # 计算acc# metric_monitor.update('Loss', loss.item())  # 后面基本都是更新进度条的操作metric_monitor.update('F1', f1_macro)metric_monitor.update("Recall", recall_macro)metric_monitor.update('Accuracy', acc)stream.set_description("mode: {epoch}.  {metric_monitor}".format(epoch="test",metric_monitor=metric_monitor))class_names_length = len(class_names)heat_maps = np.zeros((class_names_length, class_names_length))for test_real_label, test_pre_label in zip(test_real_labels, test_pre_labels):heat_maps[test_real_label][test_pre_label] = heat_maps[test_real_label][test_pre_label] + 1# print(heat_maps)heat_maps_sum = np.sum(heat_maps, axis=1).reshape(-1, 1)# print(heat_maps_sum)# print()heat_maps_float = heat_maps / heat_maps_sum# print(heat_maps_float)# title, x_labels, y_labels, harvestshow_heatmaps(title="heatmap", x_labels=class_names, y_labels=class_names, harvest=heat_maps_float,save_name="record/heatmap_{}.png".format(model_name))# 加上模型名称return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['F1']["avg"], \metric_monitor.metrics['Recall']["avg"]def show_heatmaps(title, x_labels, y_labels, harvest, save_name):# 这里是创建一个画布fig, ax = plt.subplots()# cmap https://blog.csdn.net/ztf312/article/details/102474190im = ax.imshow(harvest, cmap="OrRd")# 这里是修改标签# We want to show all ticks...ax.set_xticks(np.arange(len(y_labels)))ax.set_yticks(np.arange(len(x_labels)))# ... and label them with the respective list entriesax.set_xticklabels(y_labels)ax.set_yticklabels(x_labels)# 因为x轴的标签太长了,需要旋转一下,更加好看# Rotate the tick labels and set their alignment.plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")# 添加每个热力块的具体数值# Loop over data dimensions and create text annotations.for i in range(len(x_labels)):for j in range(len(y_labels)):text = ax.text(j, i, round(harvest[i, j], 2),ha="center", va="center", color="black")ax.set_xlabel("Predict label")ax.set_ylabel("Actual label")ax.set_title(title)fig.tight_layout()plt.colorbar(im)plt.savefig(save_name, dpi=100)# plt.show()if __name__ == '__main__':data_transforms = get_torch_transforms(img_size=params["img_size"])  # 获取图像预处理方式# train_transforms = data_transforms['train']  # 训练集数据处理方式valid_transforms = data_transforms['val']  # 验证集数据集处理方式# valid_dataset = datasets.ImageFolder(params["val_dir"], valid_transforms)  # 加载验证集# print(valid_dataset)test_dataset = datasets.ImageFolder(params["test_dir"], valid_transforms)class_names = test_dataset.classesprint(class_names)# valid_dataset = datasets.ImageFolder(params["val_dir"], valid_transforms)  # 加载验证集test_loader = DataLoader(  # 按照批次加载训练集test_dataset, batch_size=params['batch_size'], shuffle=True,num_workers=params['num_workers'], pin_memory=True,)# 加载模型model = SELFMODEL(model_name=params['model'], out_features=params['num_classes'],pretrained=False)  # 加载模型结构,加载模型结构过程中pretrained设置为False即可。weights = torch.load(model_path)model.load_state_dict(weights)model.eval()model.to(device)# 指标上的测试结果包含三个方面,分别是acc f1 和 recall, 除此之外,应该还有相应的热力图输出,整体会比较好看一些。acc, f1, recall = test(test_loader, model, params, class_names)print("测试结果:")print(f"acc: {acc}, F1: {f1}, recall: {recall}")print("测试完成,heatmap保存在{}下".format("record"))

6.2 图片预测(predict.py)

import torch
# from train_resnet import SelfNet
from train import SELFMODEL
import os
import os.path as osp
import shutil
import torch.nn as nn
from PIL import Image
from torchutils import get_torch_transformsif torch.cuda.is_available():device = torch.device('cuda')
else:device = torch.device('cpu')model_path = "../checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99501_weights.pth"  # todo  模型路径
classes_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']  # todo 类名
img_size = 224  # todo 图片大小
model_name = "resnet50d"  # todo 模型名称
num_classes = len(classes_names)  # todo 类别数目def predict_batch(model_path, target_dir, save_dir):data_transforms = get_torch_transforms(img_size=img_size)valid_transforms = data_transforms['val']# 加载网络model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)# model = nn.DataParallel(model)weights = torch.load(model_path)model.load_state_dict(weights)model.eval()model.to(device)# 读取图片image_names = os.listdir(target_dir)for i, image_name in enumerate(image_names):image_path = osp.join(target_dir, image_name)img = Image.open(image_path)img = valid_transforms(img)img = img.unsqueeze(0)img = img.to(device)output = model(img)label_id = torch.argmax(output).item()predict_name = classes_names[label_id]save_path = osp.join(save_dir, predict_name)if not osp.isdir(save_path):os.makedirs(save_path)shutil.copy(image_path, save_path)print(f"{i + 1}: {image_name} result {predict_name}")def predict_single(model_path, image_path):data_transforms = get_torch_transforms(img_size=img_size)# train_transforms = data_transforms['train']valid_transforms = data_transforms['val']# 加载网络model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)# model = nn.DataParallel(model)weights = torch.load(model_path)model.load_state_dict(weights)model.eval()model.to(device)# 读取图片img = Image.open(image_path)img = valid_transforms(img)img = img.unsqueeze(0)img = img.to(device)output = model(img)label_id = torch.argmax(output).item()predict_name = classes_names[label_id]print(f"{image_path}'s result is {predict_name}")if __name__ == '__main__':# 批量预测函数predict_batch(model_path=model_path,target_dir="D:/upppppppppp/cls/cls_torch_tem/images/test_imgs/mini",save_dir="D:/upppppppppp/cls/cls_torch_tem/images/test_imgs/mini_result")# 单张图片预测函数# predict_single(model_path=model_path, image_path="images/test_imgs/506659320_6fac46551e.jpg")

七、模型结构与参数量查看

7.1 查看模型结构(Netron 工具)

  1. 将模型转换为 ONNX 格式(代码utils/export_onnx.py):
import numpy as np
import onnxruntime
from PIL import Imageclass_names = {'0': '雏菊', '1': '蒲公英', '2': '玫瑰', '3': '向日葵', '4': '郁金香'}# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
# 预测图片
session = onnxruntime.InferenceSession(r"C:\Users\nongc\Desktop\ImageClassifier.onnx")def process_image(image_path):# 读取测试数据img = Image.open(image_path)# Resize,thumbnail方法只能进行缩小,所以进行了判断if img.size[0] > img.size[1]:img.thumbnail((10000, 256))else:img.thumbnail((256, 10000))# Crop操作left_margin = (img.width - 224) / 2bottom_margin = (img.height - 224) / 2right_margin = left_margin + 224top_margin = bottom_margin + 224img = img.crop((left_margin, bottom_margin, right_margin,top_margin))# img.save('thumb.jpg')# 相同的预处理方法img = np.array(img) / 255mean = np.array([0.485, 0.456, 0.406])  # provided meanstd = np.array([0.229, 0.224, 0.225])  # provided stdimg = (img - mean) / std# 注意颜色通道应该放在第一个位置img = img.transpose((2, 0, 1))return imgimage_path = r"C:\Users\nongc\Desktop\百度云下载\2023_pytorch110_classification_42-master\2023_pytorch110_classification_42-master\flowers_5\roses\99383371_37a5ac12a3_n.jpg"  # '1':
img = process_image(image_path)
img = np.expand_dims(img, 0)outputs = session.run([], {"modelInput": img.astype('float32')})
result_index = int(np.argmax(np.squeeze(outputs)))
result = class_names['%d' % result_index]  # 获得对应的名称print(np.squeeze(outputs), '\n', img.shape)
print(f"预测种类为: {result} 对应索引为:{np.argmax(np.squeeze(outputs))}")
# print(np.min(outputs),np.argmin(np.squeeze(outputs)),np.max(outputs))

 打开Netron 官网,拖入resnet50.onnx即可可视化模型结构。

7.2 计算参数量(get_flops.py)

import torch
from torchstat import stat
from train import SELFMODELif torch.cuda.is_available():device = torch.device('cuda')
else:device = torch.device('cpu')
model_name = "resnet50d" # todo 模型名称
num_classes = 5 # todo 类别数目
model_path = "../../checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99501_weights.pth" # todo 模型地址
model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)
weights = torch.load(model_path)
model.load_state_dict(weights)
model.eval()
stat(model, (3, 224, 224)) # 后面的224表示模型的输入大小

八、总结

本文覆盖了深度学习项目的核心流程:数据获取→清洗→划分→训练→测试→预测,并提供了可直接运行的代码和详细操作说明。对于小白来说,建议先从简单数据集(如花卉分类)入手,逐步熟悉每个环节,遇到问题可参考代码中的TODO注释和报错信息排查。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/84074.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React Flow 边事件处理实战:鼠标事件、键盘操作及连接规则设置(附完整代码)

本文为《React Agent:从零开始构建 AI 智能体》专栏系列文章。 专栏地址:https://blog.csdn.net/suiyingy/category_12933485.html。项目地址:https://gitee.com/fgai/react-agent(含完整代码示​例与实战源)。完整介绍…

java小结(一)

java(上) 模块一 1.JDK,JRE,JVM 知识点 核心内容 易混淆点 JDK定义 Java Development Kit(Java开发工具包),包含开发所需全部工具 JDK包含JRE的关系容易混淆 JRE定义 Java Runtime Environment(Jav…

ddns-go安装介绍-强大的ipv6动态域名解析神器-家庭云计算专家

ddns-go 是一款轻量级开源动态域名解析工具,专注于解决动态IP环境下的域名绑定问题,尤其适配IPv6网络环境。其核心功能包括: 1.IPv6动态解析:自动检测本地IPv6地址变化(支持网卡、接口或命令获取)&#xf…

Docker-mongodb

拉取 MongoDB 镜像: docker pull mongo 创建容器并设置用户: 要挂载本地数据目录,请替换此路径: /Users/Allen/Env/AllenDocker/mongodb/data/db docker run -d --name local-mongodb \-e MONGO_INITDB_ROOT_USERNAMEadmin \-e MONGO_INITDB_ROOT_PA…

WooCommerce缓存教程 – 如何防止缓存破坏你的WooCommerce网站?

我们在以前的文章中探讨过如何加快你的WordPress网站的速度,并研究过各种形式的缓存。 然而,像那些使用WooCommerce的动态电子商务网站,在让缓存正常工作方面往往会面临重大挑战。 在本指南中,我们将告诉你如何为WooCommerce设置…

贪心算法 Part04

总结下重叠区间问题 LC 452. 用最少数量的箭引爆气球 和 LC 435. 无重叠区间 本质上是一样的。 LC 452. 用最少数量的箭引爆气球 是求n个区间当中 , 区间的种类数量 k。此处可以理解为,重叠在一起的区间属于同一品种,没有重叠的区间当然…

云原生CD工具-Argocd+ArgoRollout入门到精通

第一章 Argo CD简介 课时1.1 Argo产品介绍 ARGO官网地址:https://argoproj.github.io/ 旗下产品有: Argo Workflows、ArgoCD 、Argo Rollouts 、Argo Events 课时1.2 什么是Argo CD Argo CD 是一个开源的持续交付工具, 是 Kubernetes 的声明式 GitOps 持续交付工具。专…

数据分析与应用---数据可视化基础

目录 Matplotlib基础绘图 (一)、pyplot绘图基础语法与常用参数 1、pyplot基础语法 (1) 创建画布与创建子图 (2) 添加画布内容 (3) 保存与显示图形 案例代码 2. 设置pyplot的动态rc参数 (二)、使用Matplotlib绘制进阶图形 1. 绘制散点图----scatter 2. 绘制折线…

PP-YOLOE-SOD学习笔记1

项目:基于PP-YOLOE-SOD的无人机航拍图像检测案例全流程实操 - 飞桨AI Studio星河社区 一、安装环境 先准备新环境py>3.9 1.先cd到源代码的根目录下 2.pip install -r requirements.txt 3.python setup.py install 这一步需要看自己的GPU情况,去飞浆…

力扣HOT100之二叉树:114. 二叉树展开为链表

这道题自己尝试着做了一下,感觉还是得用递归来做比较简单,但是一直想的是用前序遍历来构造链表,导致怎么做都不对,去看了下灵神的题解,然后问了下GPT,现在终于弄明白了。虽然构造出来的链表的排列顺序是按照…

Spring Boot 注解 @ConditionalOnMissingBean是什么

一句话总结: ConditionalOnMissingBean 是 Spring Boot 提供的一个 条件注解(Conditional Annotation),意思是: 只有当 Spring 容器中 不存在 某个 Bean 时,当前的 Bean 或配置才会被加载。 这是一种典型的…

PyInstaller 如何在mac电脑上生成在window上可执行的exe文件

PyInstaller跨平台打包限制 PyInstaller 无法直接从macOS生成Windows可执行文件,因为它需要访问目标平台的系统库和Python环境来构建可执行文件。要在macOS上为Windows打包Python应用,需要通过以下方法之一: 方法一:使用虚拟机或…

零基础设计模式——创建型模式 - 抽象工厂模式

第二部分:创建型模式 - 抽象工厂模式 (Abstract Factory Pattern) 我们已经学习了单例模式(保证唯一实例)和工厂方法模式(延迟创建到子类)。现在,我们来探讨创建型模式中更为复杂和强大的一个——抽象工厂…

【通用智能体】Serper API 详解:搜索引擎数据获取的核心工具

Serper API 详解:搜索引擎数据获取的核心工具 一、Serper API 的定义与核心功能二、技术架构与核心优势2.1 技术实现原理2.2 对比传统方案的突破性优势 三、典型应用场景与代码示例3.1 SEO 监控系统3.2 竞品广告分析 四、使用成本与配额策略五、开发者注意事项六、替…

Flask-SQLAlchemy核心概念:模型类与数据库表、类属性与表字段、外键与关系映射

前置阅读,关于Flask-SQLAlchemy支持哪些数据库及基本配置,链接:Flask-SQLAlchemy_数据库配置 摘要 本文以一段典型的 SQLAlchemy 代码示例为引入,阐述以下核心概念: 模型类(Model Class) ↔ 数…

野火鲁班猫(arrch64架构debian)从零实现用MobileFaceNet算法进行实时人脸识别(四)安装RKNN Toolkit2

RKNN Toolkit2是用来将onnx模型转成rknn专用模型,并可通过RKNN Toolkit Lite2或者RKNPU调用NPU进行加速计算的工具。 一开始我安装很多次都无法成功安装。后来跟售后技术对接,必须是PC平台的Linux环境才可以。我的电脑是windows,所以我需要用…

基于深度学习的工件检测系统设计与实现

在工业自动化领域,工件检测一直是提高生产效率和产品质量的关键环节。传统的人工检测方法不仅效率低下,而且容易受到主观因素的影响,导致误判率较高。随着深度学习技术的飞速发展,基于图像识别的自动检测系统逐渐成为研究热点。今…

CyberSecAsia专访CertiK首席安全官:区块链行业亟需“安全优先”开发范式

近日,权威网络安全媒体CyberSecAsia发布了对CertiK首席安全官Wang Tielei博士的专访,双方围绕企业在进军区块链领域时所面临的关键安全风险与防御策略展开深入探讨。 Wang博士在采访中指出,跨链桥攻击、智能合约漏洞以及私钥管理不当&#x…

Google C++ Style Guide 谷歌 C++编码风格指南,深入理解华为与谷歌的编程规范——C和C++实践指南

Google C 编程风格指南 Release Apr 07, 2017 0. ᡿享 ⡾ᵢ 4.45 ৕֒㘻 Benjy Weinberger, Craig Silverstein, Gregory Eitzmann, Mark Mentovai, Tashana Landray 㘱䈇 YuleFox, Yang.Y, acgtyrant, lilinsanity 亯ⴤѱ享 • Google Style Guide • Google 开源…

当科技邂逅浪漫:在Codigger的世界里,遇见“爱”

520,一个充满爱意的日子,人们用各种方式表达对彼此的深情。而在科技的世界里,我们也正经历着一场特别的邂逅——Codigger,一个分布式操作系统的诞生,正在以它独特的方式,重新定义我们与技术的关系。 Codigg…