深度学习入门代码详细注释-ResNet18分类蚂蚁蜜蜂

       本项目将基于PyTorch平台迁移ResNet18模型。该模型原采用ImageNet数据集(含1000个图像类别)进行训练。我们将尝试运用该模型对蚂蚁和蜜蜂进行分类(这两个类别未包含在原训练数据集中)。

       本文的原始代码参考于博客深度学习入门项目——附代码(持续更新) - 知乎,但是这位博主只给出了代码,而没有对代码进行一些必要的注释,这对刚入门的菜鸟新手来说不太友好,所以在这里,我对该代码做了一些详细的注释,希望能够帮助到和我一样新入门的菜鸟。也同时感谢原作者对代码整理所付出的劳动!

#加载所需要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLRimport torchvision
from torchvision import datasets
from torchvision import models
from torchvision import transformsimport numpy as npfrom io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFileimport matplotlib.pyplot as plt#调用urllib.request.urlopen从下载网址https://pytorch.tips/bee-zip中下载数据,并生成对象#zipresp;
#用IO流来进行操作,无需将zip文件下载到本地磁盘上
zipurl='https://pytorch.tips/bee-zip'
with urlopen(zipurl) as zipresp:with ZipFile(BytesIO(zipresp.read())) as zfile:zfile.extractall('./data')#定义训练集的变换
train_transforms = transforms.Compose([       #transforms.Compose 是一个工具,用于将多个预
#处理操作组合成一个完整的流程。它接受一个列表,列表中的每个元素是一个预处理操作。# 随机裁剪并调整大小到 224x224,比如原图像大小为512x512,是一个随机操作,每次裁剪的区域可
#能不同,裁剪出224x224transforms.RandomResizedCrop(224),         # 用于训练集,增加数据的多样性,帮助模型学习
#到更多的特征。# 随机水平翻转,概率为 0.5transforms.RandomHorizontalFlip(),     #它通过0.5的概率随机水平翻转图像,增加了数据的多样性,有助于提高模型的泛化能力和减少过拟合。一张猫的图像无论向左看还是向右看,都是#猫通过随机水平翻转,模型可以学习到更多样的图像特征。# 将图片转换为张量transforms.ToTensor(),# 标准化处理,使用预定义的均值和标准差transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet  RGB三个通道的均值std=[0.229, 0.224, 0.225]        # ImageNet  RGB三个通道的标准差)])
#定义验证集的变换
val_transforms = transforms.Compose([# 将图片调整为 256x256transforms.Resize(256),               #是一个确定性操作,每次调整大小的结果是相同的。这
#确保了验证集的图像在每次运行时都保持一致,从而保证验证结果的稳定性和可重复性。# 从中心裁剪出 224x224 的区域transforms.CenterCrop(224),    #这种组合操作确保了验证集的图像在每次运行时都保持一致,同
#时也能保证输入到模型的图像大小一致。# 将图片转换为张量transforms.ToTensor(),    #将图像从 PIL 图像或 NumPy 数组转换为 PyTorch 张量。转换后的张
#量形状为 (C, H, W),其中 C 是通道数,H 是高度,W 是宽度。像素值会被归一化到 [0, 1] 范围。# 标准化处理,使用预定义的均值和标准差transforms.Normalize(mean=[0.485, 0.456, 0.406],    # ImageNet RGB三个通道的均值std=[0.229, 0.224, 0.225]         # ImageNet RGB三个通道的标准差)])
#加载数据
#数据集
train_dataset = datasets.ImageFolder(root = './data/hymenoptera_data/train',transform = train_transforms)val_dataset = datasets.ImageFolder(root = './data/hymenoptera_data/val',transform = val_transforms)#数据加载器
train_loader = DataLoader(train_dataset,batch_size=4,     #每个批次加载 4 个样本。这意味着每次迭代会返回 4 ##个样本及其对应的标签shuffle=True,    #每个 epoch 开始时随机打乱数据。虽然验证集通常不需要打乱,但在某些情况下#打乱数据可以避免验证结果的偏差。num_workers=4                    #如果只用一个进程加载数据,GPU每次只能处理一张图像,处
#理完一张后再处理下一张。而使用 4 个子进程时,可以同时处理 4 张图像,这样可以显著减少数据加#载的#总时间。
)
val_loader = DataLoader(val_dataset,batch_size=4,shuffle=True,num_workers=4
)#生成ResNet18模型
#模型
model = models.resnet18(pretrained = True)#这是 PyTorch 提供的预定义 ResNet18 模型。pretrained=True:表示加载预训练的权重。这些权重是在 ImageNet 数据集上训练得到的,通常 #用于迁#移学习。
print(model.fc)                                                     #fc 是 ResNet18 模型中#的最后一个全连接层。默认情况下,ResNet18 的全连接层有 1000 个输出节点,对应于 ImageNet 数据
#集的 1000 个类别。
model.fc = nn.Linear(model.fc.in_features, 2) #model.fc.in_features:获取原全连接层的输入特#征数量。ResNet18 的全连接层输入特征数量为 512。nn.Linear(model.fc.in_features, 2):创建一
#个新的全连接层,输入特征数量保持不变,输出特征数量改为 2。这通常用于二分类任务。
print(model.fc)#定义超参数
model = model.to("cuda")#将模型的所有参数和缓冲区移动到 GPU 上。这使得模型可以在 GPU 上进行训#练,从而显著提高训练速度。Loss = nn.CrossEntropyLoss()#这是 PyTorch 提供的交叉熵损失函数,通常用于多分类任务。它结合了 LogSoftmax 和 NLLLoss,适用于分类任务。optim = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# optim.SGD:这是随机梯度下降(SGD)优化器。model.parameters():获取模型的所有可训练参#数。model.parameters():lr=0.001:设置学习率为 0.001。momentum=0.9:设置动量为 0.9,动量可#以帮助优化器更快地收敛,并减少振荡。exp_lr_scheduler = StepLR(optim, step_size=7, gamma=0.1)#StepLR:这是 PyTorch 提供的学习率#调度器,用于在训练过程中调整学习率。step_size=7:每 7 个 epoch,学习率会调整一次。
#gamma=0.1:每次调整学习率时,学习率会乘以 0.1。例如,如果初始学习率为 0.001,那么在第 7 个 #epoch 时,学习率会变为 0.0001。#可微调
#训练
num_epochs = 25 #定义了训练的总轮数,即模型将完整地遍历训练数据集的次数。在这里,设置为 25轮。
for epoch in range(num_epochs):#训练模型model.train()                                   #将模型设置为训练模式。这会影响某些层的
#行为,如 Dropout 和 BatchNorm,确保它们在训练时的行为与评估时不同。running_loss = 0.0                         #初始化running_loss:用于累计每个 epoch 的总#损失。running_corrects = 0                    #初始化running_corrects:用于累计每个 epoch 中
#正确预测的数量。for inputs, labels in train_loader: #train_loader:数据加载器,每次迭代返回一个批次的数
#据和标签。inputs = inputs.to("cuda")           #将输入数据移动到 GPU 上。labels = labels.to("cuda")            #将标签数据移动到 GPU 上。outputs = model(inputs) #将输入数据通过模型进行前向传播,得到模型的输出。_, preds = torch.max(outputs, 1) #获取模型输出的最大值及其索引。preds 是预测的类别索#引。loss = Loss(outputs, labels)          #计算模型输出和真实标签之间的损失值loss.backward()                            #反向传播,计算损失值的梯度,并将其传播回#模型的每个参数。optim.step()                                #根据计算得到的梯度更新模型的参数。optim.zero_grad()                       #清空之前的梯度,避免梯度累积# 累计损失和正确预测数running_loss += loss.item()/inputs.size(0)    #获取损失值的标量值。running_corrects += torch.sum(preds == labels.data)/inputs.size(0) #torch.sum(preds == labels.data):计算预测正确的数量。inputs.size(0):获取当前批次的样本数量。exp_lr_scheduler.step() #根据学习率调度器的设置更新学习率。train_epoch_loss = running_loss/len(train_loader) #计算每个 epoch 的平均损失。train_epoch_acc = running_corrects/len(train_loader)#计算每个 epoch 的平均准确率。#测试模型model.eval()running_loss = 0.0running_corrects = 0for inputs, labels in val_loader:inputs = inputs.to("cuda")labels = labels.to("cuda")outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = Loss(outputs, labels)running_loss += loss.item()/inputs.size(0)running_corrects += torch.sum(preds == labels.data)/inputs.size(0)epoch_loss = running_loss/len(val_loader)epoch_acc = running_corrects/len(val_loader)if((epoch+1)%5==0): #每隔 5 个 epoch,输出当前的训练和验证的损失和准确率。print("epoch:{},  ""Train_loss:{:.4f} Train_acc:{:.4f},  ""Loss:{:.4f}, Acc:{:.4f}".format(epoch+1, train_epoch_loss, train_epoch_acc, epoch_loss, epoch_acc))#测试可视化 这段代码定义了一个函数 imshow,用于将经过预处理的图像数据可视化,并显示其预测的类
#别。它还展示了如何使用这个函数来可视化验证集中的图像及其预测结果
def imshow(inp, title=None):#从 C-H-W 切换回 H-W-C 图像格式inp = inp.numpy().transpose((1, 2, 0)) #将输入的 PyTorch 张量转换为 NumPy 数组   从通道#优先格式(C-H-W)转换为高度-宽度-通道格式(H-W-C),以便使用 matplotlib 进行可视化。#撤销归一化mean = np.array([0.485, 0.456, 0.406]) #均值数组,用于撤销归一化。std = np.array([0.229, 0.224, 0.225])  #标准差数组,用于撤销归一化。inp = std * inp + mean  #撤销归一化操作,将图像数据恢复到原始范围。inp = np.clip(inp, 0, 1) #将图像数据裁剪到 [0, 1] 范围内,确保数据有效。plt.imshow(inp) #使用 matplotlib 的 imshow 函数显示图像。if title is not None:plt.title(title)inputs , classes = next(iter(val_loader))  #从验证集加载器中获取一个批次的数据和标签。
out = torchvision.utils.make_grid(inputs)    #将一个批次的图像拼接成一个网格,便于可视化
class_names = val_dataset.classes  #val_dataset.classes:获取验证集中的类别名称列表
outputs = model(inputs.to("cuda"))           #inputs.to("cuda"):将输入图像移动到 GPU。model(inputs.to("cuda")):将输入图像通过模型进行前向传播,得到模型的输出。_, preds = torch.max(outputs, 1)   #获取模型输出的最大值及其索引,preds 是预测的类别索引。imshow(out, title=[class_names[x] for x in preds])  #根据预测的类别索引获取对应的类别名称

                

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/89056.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/89056.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

北京饮马河科技公司 Java 实习面经

北京饮马河科技公司 Java 实习面经 本文作者:程序员小白条 本站地址:https://xbt.xiaobaitiao.top 1) 面试官:我看你这块是有一个开源的项目,这个项目主要是做什么的? 我:主要两点是亮点&…

java基础(day07)

目录 OOP编程 方法 方法的调用: 在main入口函数中调用: 动态参数: 方法重载 OOP编程 方法 概念:指为获得某种东西或达到某种目的而采取的手段与行为方式。有时候被称作“方法”,有时候被称作“函数”。例如UUID.…

使用EasyExcel动态合并单元格(模板方法)

1、导入EasyExcel依赖<dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>4.0.3</version> </dependency>2、编写实体类Data publci class Student{ ExcelProperty("姓名")pri…

jenkins 流水线比较简单直观的

//全篇没用自定义变量pipeline {agent any// 使用工具自动配置Node.js环境tools {nodejs nodejs22 // 需在Jenkins全局工具中预配置该名称的Node.js安装}//下面拉取代码通过的是流水线片段生成的stages {stage(Checkout Code) {steps {git branch: release-v1.2.6,credentials…

CV目标检测中的LetterBox操作

LetterBox类比理解&#xff1a;想象你要把一张任意形状的照片放进一个正方形的相框里&#xff0c;照片不能变形拉伸&#xff0c;所以你先等比例缩小照片&#xff0c;然后在空余的地方填上灰色背景。第1章 数学原理当我们有一个原始图像的尺寸为 19201080&#xff08;宽高&#…

Leetcode 3614. Process String with Special Operations II

Leetcode 3614. Process String with Special Operations II 1. 解题思路2. 代码实现 题目链接&#xff1a;3614. Process String with Special Operations II 1. 解题思路 这一题思路上是一个逆推的思路。 首先&#xff0c;我们顺序走一轮不难得到最终我们能够获得的字符串…

.NET ExpandoObject 技术原理解析

&#x1f31f; .NET ExpandoObject 技术原理解析 引用&#xff1a; .NET 剖析4.0上ExpandoObject动态扩展对象原理风潇潇人渺渺快意刀山中草 #mermaid-svg-RtpHctpdchPPN1Xo {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mer…

放苹果(信息学奥赛一本通-T1192)

【题目描述】把M个同样的苹果放在N个同样的盘子里&#xff0c;允许有的盘子空着不放&#xff0c;问共有多少种不同的分法&#xff1f;&#xff08;用K表示&#xff09;5&#xff0c;1&#xff0c;1和1&#xff0c;5&#xff0c;1 是同一种分法。【输入】第一行是测试数据的数目…

(懒人救星版)CNN_Kriging_NSGA2_Topsis(多模型融合典范)深度学习+SCI热点模型+多目标+熵权法 全网首例,完全原创,早用早发SCI

全网首例&#xff0c;完全原创&#xff0c;早用早发SCI&#xff08;多模型融合典范&#xff09;机器学习SCI热点模型多目标熵权法(懒人救星版)BP_Kriging_NSGA2_Topsis 改进克里金工作量大&#xff1a;多模型融合创新性&#xff1a;首次结合BP神经网络和克里金多目标利用 BP神…

LeetCode热题100【第一天】

第一题 两数之和 给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数&#xff0c;并返回它们的数组下标。 你可以假设每种输入只会对应一个答案&#xff0c;并且你不能使用两次相同的元素。 你可以按任意顺序返回…

AI Linux 运维笔记

运维基本概念 IT运维是指通过专业技术手段&#xff0c;确保企业的IT系统和网络持续、安全、稳定运行&#xff0c;保障业务的连续性。运维涵盖计算机网络、应用系统、硬件环境和服务流程的综合管理。主要分为: 系统运维、数据库运维、自动化运维、容器运维、云计算运维、信创运维…

Redis性能基准测试

基准环境 机器&#xff1a;AWS EC2 c4.8xlarge&#xff08;同机部署 Redis Server 与 ReJSONBenchmark 工具&#xff0c;通过网络栈连接&#xff09;测试工具&#xff1a;ReJSONBenchmark&#xff08;Go 实现、可配置并发&#xff09;模式&#xff1a;非管线&#xff08;non-pi…

XML外部实体注入与修复方案

XML外部实体注入&#xff08;XXE&#xff09;是一种严重的安全漏洞&#xff0c;攻击者利用XML解析器处理外部实体的功能来读取服务器内部文件、执行远程请求&#xff08;SSRF&#xff09;、扫描内网端口或发起拒绝服务攻击。以下是详细解释和修复方案&#xff1a;XXE 攻击原理外…

解决高并发场景中的连接延迟:TCP 优化与队头阻塞问题剖析

你是否在高并发场景下遇到过这种情况&#xff1a;系统性能本来不错&#xff0c;但在请求量大增的时刻&#xff0c;连接延迟暴涨&#xff0c;响应时间直线飙升&#xff0c;甚至整个服务都变得不可用&#xff1f;当你打开监控时&#xff0c;CPU、内存、带宽都在正常范围内&#x…

Web学习笔记4

CSS概述1、CSS简介CSS&#xff0c;层叠样式表&#xff0c;是一种样式表语言&#xff0c;用以描述HTML的呈现内容的方式&#xff08;美化网页&#xff09;。CSS书写规则是&#xff1a;选择器{属性名&#xff1a;属性值}的键值对CSS有三种引入方式&#xff0c;分别为&#xff1a;…

Spring AI 初学者指南:从入门到实践与常用大模型介绍

作为 Java 开发者&#xff0c;当 AI 浪潮席卷而来时&#xff0c;如何在熟悉的 Spring 生态中快速拥抱大模型开发&#xff1f;Spring AI 的出现给出了答案。本文将从初学者视角出发&#xff0c;带你了解 Spring AI 的核心概念、使用方法&#xff0c;并介绍与之搭配的常用大模型&…

C#自定义控件

1。C#中控件和组件的区别&#xff1a; 一般组件派生于&#xff1a;Component类&#xff0c;所以从此类派生出的称之为组件。 一般用户控件派生于:Control类或UserControl类&#xff0c;所以从该类派生出的称之为用户控件。 他们之间的关系主要是&#xff1a;UserControl继承Con…

网络资产测绘工具全景解析:七大平台深度洞察

​一、资产测绘工具的核心价值​网络资产测绘&#xff08;Cyber Asset Intelligence&#xff09;技术通过主动扫描与被动分析&#xff1a;实时发现全球暴露的网络设备&#xff08;服务器、摄像头、IoT设备&#xff09;自动化构建资产指纹库&#xff08;操作系统/服务/框架版本&…

编程语言设计目的与侧重点全解析(主流语言深度总结)

编程语言的设计本质上是对计算逻辑的形式化表达与工程约束的平衡&#xff0c;不同语言因目标场景、时代需求和技术哲学的差异&#xff0c;形成了独特的设计范式。以下从系统级、应用级、脚本/动态、函数式、并发/安全等维度&#xff0c;选取10种最具代表性的编程语言&#xff0…

重学前端003 --- 响应式网页设计 CSS 颜色

文章目录文档声明head颜色模型div根据在这里 Freecodecamp 实践&#xff0c;记录笔记总结。 文档声明 在文档顶部添加 DOCTYPE html 声明 <!DOCTYPE html>head title 元素为搜索引擎提供了有关页面的额外信息。 它还通过以下两种方式显示 title 元素的内容&#xff1a…