Pytorch02:深度学习基础示例——猫狗识别

 一、第三方库介绍

库/模块功能
torch提供张量操作、自动求导、优化算法、神经网络模块等基础设施。
torchvision计算机视觉工具集,提供预训练模型、数据集、图像转换等功能。
datasets (torchvision)用于加载常见数据集(如 ImageNet、CIFAR-10、MNIST)。
transforms (torchvision)提供图像数据的预处理、数据增强操作(如大小调整、裁剪、转换为张量、归一化等)。
nn (torch)用于定义和构建神经网络,包含各类网络层、损失函数等。
optim (torch)提供优化算法(如 Adam、SGD、RMSprop)用于更新神经网络权重。
DataLoader (torch.utils.data)用于批量加载数据,支持多线程加载数据,按批次读取数据。
Image (PIL)用于图像处理(加载、裁剪、旋转、缩放、保存图像等)。
ResNet18_Weights (torchvision.models)提供 ResNet18 模型的预训练权重,可用于迁移学习。

二、训练数据集介绍

三、原理简介

        该代码使用PyTorch训练一个基于ResNet-18的猫狗分类模型。通过加载并处理数据、训练模型、调整最后输出层、使用Adam优化器进行反向传播,并在每个训练周期输出损失与准确率。训练完毕后,保存模型用于后续预测。

四、代码思路简介

  • 加载数据 → 使用 datasets.ImageFolder 加载猫狗数据集,并应用图像转换。
  • 构建模型 → 使用预训练的 ResNet-18 模型,修改输出层以适应2类分类。
  • 定义损失和优化器 → 使用交叉熵损失函数和 Adam 优化器。
  • 训练模型 → 遍历数据集,前向传播、计算损失、反向传播、更新模型参数。
  • 保存模型 → 训练完成后,保存模型权重。
  • 预测图片 → 加载已训练模型,输入图片进行预测并输出分类结果。

五、代码

场景:使用pytorch识别猫狗
猫的图片路径:F:\pycharm\AIDEMO\data\cat
狗的图片路径:F:\pycharm\AIDEMO\data\dog
需要判断的图片:F:\pycharm\AIDEMO\01.jpeg

import torch
import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from PIL import Image
from torchvision.models import ResNet18_Weights# 定义transform类(视觉转换类,将图片格式转化为张量格式)
transform = transforms.Compose([transforms.Resize((128, 128)),  # 将图片缩放到统一大小transforms.ToTensor(),  # 转换为Tensor格式transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化处理
])def train_model(data_dir, num_epochs=10, batch_size=32, save_path='cat_dog_model.pth'):"""训练模型并保存。:param data_dir: 数据路径,包含cat和dog文件夹:param num_epochs: 训练周期,默认为10:param batch_size: 批次大小,默认为32:param save_path: 模型保存路径,默认为'cat_dog_model.pth'"""# 1. 加载训练数据train_data = datasets.ImageFolder(root=data_dir,  # 数据路径transform=transform)print(train_data.class_to_idx)  # 输出文件夹编号,例如这里输出{'cat': 0, 'dog': 1},表达0代表猫猫,1代表狗狗train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)# 2. 使用预训练的ResNet18模型model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.fc = nn.Linear(model.fc.in_features, 2)  # 修改输出层以适应2类分类(猫、狗)# 3. 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 4. 开始训练模型for epoch in range(num_epochs):model.train()  # 设置模型为训练模式running_loss = 0.0  # 初始化损失值correct = 0  # 模型预测准确数total = 0  # 模型预测总数for images, labels in train_loader:optimizer.zero_grad()  # 清除之前的梯度outputs = model(images)  # 前向传播,得出预测结果loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()# 计算准确率_, predicted = torch.max(outputs, 1)  # 获取预测结果total += labels.size(0)  # 累计总样本数correct += (predicted == labels).sum().item()  # 累计预测正确的样本数# 输出训练周期的损失和准确率accuracy = 100 * correct / total  # 计算准确率print(f'周期 [{epoch + 1}/{num_epochs}], 损失: {running_loss / len(train_loader):.4f}, 准确率: {accuracy:.2f}%')if accuracy == 100:  # 准确率达到100%就停止训练,避免过度拟合break# 保存训练模型torch.save(model.state_dict(), save_path)def predict_image(model_path, img_path):"""加载训练好的模型并进行图片预测。:param model_path: 训练好的模型路径:param img_path: 需要预测的图片路径:return: 预测结果(猫或狗)"""# 加载模型model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.fc = nn.Linear(model.fc.in_features, 2)model.load_state_dict(torch.load(model_path))model.eval()  # 设置模型为评估模式# 预测指定图片img = Image.open(img_path)img = transform(img).unsqueeze(0)  # 将图片处理成张量输出并增加batch维度# 模型预测with torch.no_grad():  # 不需要梯度计算,只是进行模型预测outputs = model(img)_, predicted = torch.max(outputs, 1)# 输出预测结果return "这是猫的图片" if predicted.item() == 0 else "这是狗的图片"if __name__ == "__main__":# 01 训练出模型(若已训练出准确度较高模型,可注释下面两句话,直接用训练完毕的模型预测)data_dir = 'F:/pycharm/AIDEMO/data'  # 数据路径train_model(data_dir, num_epochs=10, batch_size=32)# 02 预测指定图片img_path = 'F:/pycharm/AIDEMO/data/01.jpeg'  # 图片路径result = predict_image('cat_dog_model.pth', img_path)print(result)

六、输出结果展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/92150.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/92150.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spring简单项目实战

项目路径 modelspackage com.qcby.demo1;import com.qcby.service.UserService; import com.qcby.service.UserServiceImpl;public class Dfactory {public UserService createUs(){System.out.println("实例化工厂的方式...");return new UserServiceImpl();} }pack…

ServBay for Windows 1.4.0 发布:新增MySQL、PostgreSQL等数据库自定义配置

各位 Windows 平台的开发者们, ServBay 始终致力于为您打造一个强大、高效且灵活的本地开发环境。距离上次更新仅过去短短一周,经过我们技术团队的快速开发,我们正式推出了 ServBay for Windows 1.4.0 版本。 专业开发者不仅需要一个能用的环…

python网络爬虫小项目(爬取评论)超级简单

python网络爬虫小项目(爬取评论)超级简单 学习python网络爬虫的完整路径: (第一章) python网络爬虫(第一章/共三章:网络爬虫库、robots.txt规则(防止犯法)、查看获取网页源代码)-…

本周大模型新动向:奖励引导、多模态代理、链式思考推理

点击蓝字关注我们AI TIME欢迎每一位AI爱好者的加入!01Iterative Distillation for Reward-Guided Fine-Tuning of Diffusion Models in Biomolecular Design本文提出了一种用于生物分子设计中奖励引导生成的扩散模型微调框架。扩散模型在建模复杂、高维数据分布方面…

JAVA+AI教程-第三天

我将由简入繁,由零基础到详细跟大家一起学习java---------------------------------------------------------------------01、程序流程控制:今日课程介绍02、程序流程控制:if分支结构if分支有三种形式,执行顺序就是先执行if&…

自定义命令行解释器shell

目录 一、模块框架图 二、实现目标 三、实现原理 四、全局变量 五、环境变量函数 六、初始化环境变量表函数 七、输出命令行提示符模块 八、提取命令输入模块 九、填充命令行参数表模块 十、检测并处理内建命令模块 十一、执行命令模块 十二、源码 一、模块框架图…

uniapp使用uni-ui怎么修改默认的css样式比如多选框及样式覆盖小程序/安卓/ios兼容问题

修改 uni-ui 多选框 (uni-data-checkbox) 的默认样式 在 uniapp 中使用 uni-ui 的 uni-data-checkbox 组件时,可以通过以下几种方式修改其默认样式: 方法一:使用深度选择器格式一:在页面的 style 部分使用深度选择器 >>>…

《Linux 环境下 Nginx 多站点综合实践:域名解析、访问控制与 HTTPS 加密部署》​

综合练习:请给openlab搭建web网站,网站需求: 1.基于域名www.openlab.com可以访问网站内容为 welcome to openlab!!, 2.给该公司创建三个子界面分别显示学生信息,教学资料和缴费网站,基于www.openlab.com/student 网站访…

网络基础1-11综合实验(eNSP):vlan/DHCP/Web/HTTP/动态PAT/静态NAT

注:在华为模拟器(eNSP)上做的实验其中,在内网实验:Vlan/DHCP/VWeb/HTTP,在外网实验:动态PAT/静态NAT一、拓扑结构1. 核心设备与连接设备接口连接对象VLAN/IP角色LSW2/LSW3Ethernet 0/0/1-2PC1/P…

Mac上安装Claude Code的步骤

以下是基于现有信息的简明安装指南,适用于macOS系统。请按照以下步骤操作: 前提条件 操作系统:macOS 10.15或更高版本。Node.js和npm:Claude Code基于Node.js,需安装Node.js 18和npm。请检查是否已安装: …

MybatisPlus-15.扩展功能-逻辑删除

一.逻辑删除配置逻辑删除的字段时,logic-delete-field字段配置的是逻辑删除的实体字段名。字段类型可以是boolean和integer。在java中默认是boolean类型。逻辑已删除值默认为1,而逻辑未删除值默认为0。当是1时代表已删除(1在数据库表中为true&#xff0c…

IDEA 同时修改某个区域内所有相同变量名

在 IntelliJ IDEA 中,同时修改某个区域内所有 相同变量名 的快捷键是: ✅ Shift F6(重命名变量) 但这个快捷键默认是 全局重命名,如果你想 仅修改某个方法或代码块内的变量名,可以这样做:&…

Telink BLE 低功耗学习

低功耗管理(Low Power Management)也可以称为功耗管理(Power Management),本⽂档中会简称为PM。Telink低功耗解惑我查阅多连接SDK开发手册时,低功耗管理章节看了两三遍也没太明白,有以下几个问题…

设备管理系统(MMS)如何在工厂MOM功能设计和系统落地

一、核心系统功能模块设备管理系统围绕设备全生命周期管理设计,涵盖基础数据管理、设备运维全流程管控及统计分析功能,具体如下:基础数据管理设备与备件台账:包含设备台账(设备编号、识别码、型号、生产日期等&#xf…

低空经济展 | 牧羽天航空携飞行重卡AT1300亮相2025深圳eVTOL展

为深入推动低空经济产业高质量发展,构建全球eVTOL(电动垂直起降飞行器)产业交流合作高端平台,2025深圳eVTOL展定于2025年9月23日至25日在深圳坪山燕子湖国际会展中心隆重举办。本届展会以“低空经济・eVTOL・航空应急救援・商载大…

CS231n-2017 Lecture4神经网络笔记

神经网络:我们之前的线性分类器可以接受输入,进而给出评分,这是一种线性变换,再此基础上,我们对这种线性变换结果进行非线性变换,并输入到下一层线性分类器中,这个过程就像是人类大脑神经的运作…

暑期算法训练.5

目录 20. 力扣 34.在排序数组中查找元素的第一个位置和最后一个位置 20.1 题目解析: 20.2 算法思路: 20.3 代码演示: ​编辑 20.4 总结反思: 21.力扣 69.x的平方根 21.1 题目解析: 21.2 算法思路:…

【HDLBits习题详解 2】Circuit - Sequential Logic(5)Finite State Machines 更新中...

1. Fsm1(Simple FSM 1 - asynchronous reset)状态机可分为两类:(1)Mealy状态机:输出由当前状态和输入共同决定。输入变化可能立即改变输出。(2)Moore状态机:输出仅由当前…

多级缓存(亿级流量缓存)

传统缓存方案问题 多级缓存方案 流程 1.客户端浏览器缓存页面静态资源; 2. 客户端请求到Nginx反向代理;[一级缓存_浏览器缓存] 3.Nginx反向代理将请求分发到Nginx集群(OpenResty); 4.先重Nginx集群OpenResty中获取Nginx本地缓存数据;[二级缓存_Nginx本地缓存] 5.若Nginx本地缓存…

浅谈Rust语言特性

如大家所了解的,Rust是一种由Mozilla开发的系统编程语言,专注于内存安全、并发性和高性能,旨在替代C/C等传统系统编程语言。Rust 有着非常优秀的特性,例如:可重用模块 内存安全和保证(安全的操作与不安全的…