【深度学习】PyTorch从0到1——手写你的第一个卷积神经网络模型,AI模型开发全过程实战

引言

本次准备建立一个卷积神经网络模型,用于区分鸟和飞机,并从CIFAR-10数据集中选出所有鸟和飞机作为本次的数据集。

以此为例,介绍一个神经网络模型从数据集准备、数据归一化处理、模型网络函数定义、模型训练、结果验证、模型文件保存,端到端的模型全生命周期,方便大家深入了解AI模型开发的全过程。

 

一、网络场景定义与数据集准备

1.1 数据集准备

本次我准备使用CIFAR10数据集,它是一个简单有趣的数据集,由60000张小RGB图片构成(32像素*32像素),每张图类别标签用1~10数字表示

%matplotlib inline
from matplotlib import pyplot as pltfrom torchvision import datasets
data_path = '/content/sample_data'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True)type(cifar10).__mro__

 

1.2 查看数据集类别示例

class_names = ['airplane', 'aotomobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']fig = plt.figure(figsize=(8, 3))
num_classes = 10
for i in range(num_classes):ax = fig.add_subplot(2, 5 ,1 + i, xticks=[], yticks=[])ax.set_title(class_names[i])img = next(img for img, label in cifar10 if label == i)plt.imshow(img)
plt.show()

 

1.2.1 输出单张图像类别及展示图片

img, label = cifar10[99]
img, label, class_names[label]

plt.imshow(img)
plt.show()

 

1.3 数据集Dataset变换

使用torchvision.transforms模块,将PIL图像变换为PyTorch张量,用于图像分类

1.3.1 将单张图像转换为张量,输出张量大小

from torchvision import transformsto_tensor = transforms.ToTensor()
img_t = to_tensor(img)
img_t.shape

1.3.2 将CIFAR10数据集转换为张量

tensor_cifar10 = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())tensor_cifar10.__len__()

1.4 数据归一化

使用transforms.Compose()将图像连接起来,在数据加载器中直接进行数据归一化和数据增强操作

使用transforms.Normalize(),计算数据集中每个通道的平均值和标准差,使每个通道的均值为0,标准差为1

imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3)
imgs.shape

1.4.1 计算每个通道的平均值(mean)

imgs.view(3, -1).mean(dim=1)

1.4.2 计算每个通道的标准差(stdev)

imgs.view(3, -1).std(dim=1)

1.4.3 使用transforms.Normailze()对数据集归一化

使每个数据集的通道的均值为0,标准差为1

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))

二、使用nn.Module编写第一个识别鸟与飞机的网络模型

2.1 构建鸟与飞机的训练集和验证集

2.1.1 准备CIFAR10数据集

cifar10 = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))]))
cifar10_val = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))]))

2.1.2 构建CIFAR2-数据集

label_map = {0:0, 2:1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])for img, label in cifar10if label in [0, 2]
]cifar2.__len__()

2.1.3 构建CIFAR2-验证集

cifar2_val = [(img, label_map[label])for img, label in cifar10_valif label in [0, 2]]cifar2_val.__len__()

2.1.4 准备批处理图像

img, _ = cifar2[0]plt.imshow(img.permute(1, 2, 0))
plt.show()

img
img.shape

2.2 编写第一个nn.Module子模块的网络定义

import torch
import torch.nn as nn
import torch.optim as optimclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.act1 = nn.Tanh()self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)self.act2 = nn.Tanh()self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(8 * 8 * 8, 32)self.act3 = nn.Tanh()self.fc2 = nn.Linear(32, 2)def forward(self, x):out = self.pool1(self.act1(self.conv1(x)))out = self.pool2(self.act2(self.conv2(out)))out = out.view(-1, 8 * 8 * 8)out = self.act3(self.fc1(out))out = self.fc2(out)return out

 

2.2.1 将网络模型实例化,并输出模型参数

model = Net()numel_list = [p.numel() for p in model.parameters()]
sum(numel_list), numel_list

2.3 使用函数式API,优化nn.Module网络函数定义

import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)self.fc1 = nn.Linear(8 * 8 * 8, 32)self.fc2 = nn.Linear(32, 2)def forward(self, x):out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)out = out.view(-1, 8 * 8 * 8)out = torch.tanh(self.fc1(out))out = self.fc2(out)return outmodel = Net()
model(img.unsqueeze(0))

2.4 定义网络模型的训练循环函数,并执行训练

import datetimedef training_loop(n_epochs, optimizer, model, loss_fn, train_loader):for epoch in range(1, n_epochs + 1):loss_train = 0.0for imgs, labels in train_loader:outputs = model(imgs)loss = loss_fn(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()loss_train += loss.item()if epoch == 1 or epoch %10 == 0:print('{} Epoch {}, Training loss{}'.format(datetime.datetime.now(), epoch, loss_train / len(train_loader)))train_loader = torch.utils.data.DataLoader(cifar2, batch_size = 64, shuffle=True)model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()training_loop(n_epochs = 100,optimizer = optimizer,model = model,loss_fn = loss_fn,train_loader = train_loader,

 

2.4.1 训练结果(耗时7分钟)

2025-08-17 15:13:20.123706 Epoch 1, Training loss0.5672952472024663
2025-08-17 15:14:01.667640 Epoch 10, Training loss0.32902660861516453
2025-08-17 15:14:47.187795 Epoch 20, Training loss0.2960508146863075
2025-08-17 15:15:33.119990 Epoch 30, Training loss0.26820498961172284
2025-08-17 15:16:19.303661 Epoch 40, Training loss0.24607981879050564
2025-08-17 15:17:04.858228 Epoch 50, Training loss0.22783752284042394
2025-08-17 15:17:50.712569 Epoch 60, Training loss0.2095268357806145
2025-08-17 15:18:36.846523 Epoch 70, Training loss0.19460647420328894
2025-08-17 15:19:22.404563 Epoch 80, Training loss0.18098321051639357
2025-08-17 15:20:08.067236 Epoch 90, Training loss0.16757476806735536
2025-08-17 15:20:54.041604 Epoch 100, Training loss0.15512346253273593

2.5 测量准确率(使用验证集)

train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)def validate(model, train_loader, val_loader):for name, loader in [("train", train_loader), ("val", val_loader)]:correct = 0total = 0with torch.no_grad():for imgs, labels in loader:outputs = model(imgs)_, predicted = torch.max(outputs, dim=1)total += labels.shape[0]correct += int((predicted == labels).sum())print("Accuracy {}: {:.2f}".format(name, correct/total))validate(model, train_loader, val_loader)

2.6 保存并加载我们的模型

2.6.1 保存模型

torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')

torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')

2.6.2 模型pt文件生成

包含模型的所有参数,即2个卷积模块和2个线性模块的权重和偏置

2.6.3 加载参数到模型实例

loaded_model = Net()
loaded_model.load_state_dict(torch.load(data_path+'birds_vs_airplanes.pt'))

三、小结

至此,我们完成一个卷积神经网络模型birds_vs_airplanes的构建,可用于图像分类识别,区分图片是鸟还是飞机,准确性高达94!

我们从数据集准备、数据集准备、数据归一化处理、模型网络函数定义、模型训练、结果验证、模型文件保存,并将模型参数加载到另一个新模型实例中,端到端完整串联一个神经网络模型全生命周期的过程,加深对AI模型开发的理解,这是个经典案例,快来试试吧~

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/96081.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/96081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云计算核心技术之容器技术

一、容器技术 1.1、为什么需要容器 在使用虚拟化一段时间后,发现它存在一些问题:不同的用户,有时候只是希望运行各自的一些简单程序,跑一个小进程。为了不相互影响,就要建立虚拟机。如果建虚拟机,显然浪费就…

微信小程序通过uni.chooseLocation打开地图选择位置,相关设置及可能出现的问题

前言 uni.chooseLocation打开地图选择位置,看官方文档介绍的比较简单,但是需要注意的细节不少,如果没有注意可能就无法使用该API或者报错,下面就把详细的配置方法做一下介绍。 一、勾选位置接口 ①在uniapp项目根目录找到manif…

从财务整合到患者管理:德国医疗集团 Asklepios完成 SAP S/4HANA 全链条升级路径

目录 挑战 解决方案 详细信息 Asklepios成立于1985年,目前拥有约170家医疗机构,是德国大型私营诊所运营商。Asklepios是希腊和罗马神话中的医神。 挑战 Asklepios希望进一步扩大其作为数字医疗保健集团的地位。2020年9月,该公司与SNP合作…

高频PCB厂家及工艺能力分析

一、技术领先型厂商(适合高复杂度、高可靠性设计)这类厂商在高频材料处理、超精密加工和信号完整性控制方面具备深厚积累,尤其适合军工、卫星通信、医疗设备等严苛场景:深南电路:在超高层板和射频PCB领域是行业标杆&am…

AJAX 与 ASP 的融合:技术深度解析与应用

AJAX 与 ASP 的融合:技术深度解析与应用 引言 随着互联网技术的不断发展,AJAX(Asynchronous JavaScript and XML)和ASP(Active Server Pages)技术逐渐成为构建动态网页和应用程序的重要工具。本文将深入探讨AJAX与ASP的融合,分析其原理、应用场景以及在实际开发中的优…

MuMu模拟器Pro Mac 安卓手机平板模拟器(Mac中文)

原文地址:MuMu模拟器Pro Mac 安卓手机平板模拟器 MuMu模拟器 Pro mac版,是一款MuMuPlayer安卓模拟器,可以畅快运行安卓游戏和应用。 MuMu模拟器Pro搭载安卓12操作系统,极致释放设备性能,最高支持240帧画面效果&#…

Oracle维护指南

Part 1 Oracle 基础与架构#### **1.1 概述** - **Oracle 数据库版本历史与特性对比** - **版本演进**: - Oracle 8i(1999):支持 Internet 应用,引入 Java 虚拟机(JVM)。 - Oracle 9i&#…

如何为PDF文件批量添加骑缝章?

骑缝章跨越多页文件的边缘加盖,一旦文件被替换其中某一页或顺序被打乱,印章就无法对齐,能立刻发现异常。这有效保障了文件的完整性和真实性。它是纯净免费,不带广告,专治各类PDF盖章需求。用法极简:文件直接…

组合时代的 TOGAF®:为模块化企业重新思考架构

随着企业努力追求敏捷性和创新性,组合性正逐渐成为一项基础性的设计原则。组合思维改变了企业交付能力的方式 —— 更倾向于采用模块化、独立的组件,这些组件可以快速组装和重组。本文探讨了长期以来作为企业架构框架的TOGAF标准如何演进以支持组合架构。…

电子元器件-电阻终篇:基本原理,电阻分类及特点,参数/手册详解,电阻作用及应用场景,电阻选型及实战案例

目录 一、基本原理 1.1 介绍 1.2 计算公式​编辑 1.3 单位 1.4 标称值 二、分类及特点 2.1电阻分类及特点介绍 2.2常用电阻器件详细介绍 三、参数/数据手册解读 3.1 阻值 3.2 封装&功率 3.3 精度 3.5 额定电压 3.6 温度系数(TCR) 3.7 扩展 四、作用与使用场…

【软件测试】电商购物项目-各个测试点整理(六)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 1、优惠券测试点 …

心路历程-启动流程的概念

我们之前已经安装过系统,其实兴奋的内心已经无以言表; 记得刚开始的那份喜悦是没办法演说的;可是高兴之余,好像突然又心情EMO了; 为何呢?因为系统装完了,你也不知道能够干什么; 所以…

Kubernetes Ingress实战:从环境搭建到应用案例

目录 一、概述 版本对比图 二、 Ingress应用案例 2.1 环境准备 2.2 验证-NodePort模式 设置Http代理 2.3 验证-LoadBalancer模式 修改ARP模式,启用严格ARP模式 搭建metallb支持LoadBalancer 普通的service测试 ingress访问测试: 一、概述 Ser…

项目发布上线清单

说明:博主想整理一份项目发布上线的清单,在每次发布上线前,对照清单一一核对,避免遗漏(往事不堪回首),欢迎大家补充。 前端是否有与后端协同发布的接口? 如果有,先发前端…

HTB Information Gathering - Web Edition最后的测验

因为它没有DNS解析,,所以不要尝试去使用dns枚举所有枚举出来的子域,马上修改hosts文件,与ip和域名填好,因为它不依赖dns通过vhost子域爆破 爬虫登场 w*****.inlanefreight.htb:32508爬到之后不要去理会那个api,除了填答案,,,其他任何用处都没有,不要浪费时间后面就不能剧透了,可…

IDEA、Pycharm、DataGrip等激活破解冲突问题解决方案之一

Jetbranis旗下的软件破解冲突问题解决方案之一,不一定适用所有破解包 问题:在使用Pycharm破解包破解该软件后,同样是jetbranis旗下软件的Datagrip却失去了之前破解的效果,需要重新破解,重新成功破解datagrip后&#xf…

使用 uv管理 Python 虚拟环境:比conda更快、更轻量的现代方案

文章目录什么是 uv?安装 uv在线安装(推荐)Windows 系统Linux / macOS 系统离线安装步骤 1:获取二进制包步骤 2:解压并移动到可执行路径步骤 3:设置环境变量验证安装创建并激活虚拟环境创建虚拟环境输出示例…

课堂记忆项目开发日志

课堂记忆项目开发日志 日期: 2025年8月18日 1. 基础实现 项目目标: 创建一个动态、美观的“课堂记忆”页面,展示教师信息、教学成果、学生反馈、未来计划、教学成就和教学金句。 实现交互功能,包括按钮点击展开内容、图片点击弹出详细信息、图表展示数据。 技术栈: HTML5 C…

蓝桥杯算法之搜索章 - 7

大家好,不同的时间,相同的地点!又和大家见面了,接下来我将带来多源BFS的内容 通过多源BFS的学习,大家将对BFS理解更加深入! lets go! 前言 通过前面内容的学习,大家肯定已经对于BFS有了一定理解…

onRequestHide at ORIGIN_CLIENT reason HIDE_SOFT_INPUT fromUser false

这个错误日志 onRequestHide at ORIGIN_CLIENT reason HIDE_SOFT_INPUT fromUser false 通常出现在 Android 平台的 WebView 或混合应用(如 Cordova/Capacitor)中,与软键盘(Soft Input)的隐藏行为有关。以下是可能的原…