Day 40训练

Day 40 训练

  • PyTorch 图像数据训练与测试的规范写法
    • 单通道图像的规范训练流程
      • 数据预处理与加载
      • 模型定义
      • 训练与测试函数封装
      • 模型训练执行
    • 彩色图像的扩展应用
      • 数据预处理调整
      • 模型结构调整
    • 关键要点总结


知识点回顾:

彩色和灰度图片测试和训练的规范写法:封装在函数中
展平操作:除第一个维度batchsize外全部展平
dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

PyTorch 图像数据训练与测试的规范写法

在深度学习项目中,规范的代码结构能极大提升开发效率与代码可维护性。本文将基于 PyTorch 框架,详细讲解图像数据训练和测试的规范写法,从单通道图像到彩色图像,助你构建高效、清晰的模型训练流程。

单通道图像的规范训练流程

数据预处理与加载

我们以 MNIST 手写数字数据集为例,其为单通道灰度图像。数据预处理是模型训练的起点,我们利用 torchvision.transforms 对图像进行转换:

transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # 使用 MNIST 数据集的均值和标准差进行标准化
])

接着加载数据集并创建数据加载器:

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

模型定义

针对 MNIST 图像尺寸(28×28),定义一个多层感知机(MLP)模型:

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将 28x28 图像展平为 784 维向量self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = self.layer1(x)x = self.relu(x)x = self.layer2(x)return x

训练与测试函数封装

为提升代码复用性与可读性,我们将训练和测试逻辑封装为函数:

def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()all_iter_losses = []iter_indices = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单 Batch 损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')plot_iter_losses(all_iter_losses, iter_indices)return epoch_test_accdef test(model, test_loader, criterion, device):model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy

模型训练执行

设置训练轮次并启动训练:

epochs = 2
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

彩色图像的扩展应用

对于彩色图像(如 CIFAR-10 数据集),处理流程与单通道图像类似,主要差异在于:

数据预处理调整

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 适应彩色图像的标准化
])

模型结构调整

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将 3x32x32 图像展平为 3072 维向量self.layer1 = nn.Linear(3072, 512)self.relu1 = nn.ReLU()self.dropout1 = nn.Dropout(0.2)self.layer2 = nn.Linear(512, 256)self.relu2 = nn.ReLU()self.dropout2 = nn.Dropout(0.2)self.layer3 = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x)x = self.layer1(x)x = self.relu1(x)x = self.dropout1(x)x = self.layer2(x)x = self.relu2(x)x = self.dropout2(x)x = self.layer3(x)return x

关键要点总结

  1. 数据处理规范化 :利用 DataLoaderDataset 对数据进行分批次处理,提高数据加载效率。
  2. 模型结构清晰化 :明确展平操作在图像任务中的应用,彩色图像需考虑通道维度。
  3. 训练测试函数封装 :将训练和测试逻辑封装为函数,便于参数调整与复用,为多模型对比奠定基础。
  4. 迭代损失记录 :记录每个迭代的损失,绘制损失曲线辅助训练过程分析。

通过遵循上述规范写法,无论是单通道还是彩色图像数据,都能高效地完成模型训练与测试任务,在实际项目中可根据需求灵活扩展与优化。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/83661.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

杰理可视化SDK--系统死机异常调试

杰理可视化SDK--系统死机异常调试 系统异常原因杰理SDK异常调试准备工作杰理SDK系统异常定位异常代码示例1异常代码示例2 在使用杰理可视化SDK进行软件开发时,往往会遇到一些系统异常问题,系统异常是指芯片在运行代码时,由于软件或硬件状态出…

图简记。。

模仿&#xff1a; algorithm-journey/src/class059/Code01_CreateGraph.java at main algorithmzuo/algorithm-journey Code01_CreateGraph C语言&#xff1a; #include <stdio.h> #include <stdlib.h> #include <string.h>#define MAXN 11 #define MAX…

Linux 常用命令与 Shell 简介

文章目录 **Linux 常用命令与 Shell 简介****Shell 简介****什么是 Shell&#xff1f;****Shell 的工作原理****常见 Shell 类型****命令行基础****Tab 补全与通配符** **Linux 常用命令****1. 入门必备命令****1.1 寻求帮助 - man 命令****1.2 用户间切换 - su 命令****1.3 特…

基于51单片机的超声波智能避障小车仿真

目录 具体实现功能 设计介绍 资料内容 全部内容 资料获取 具体实现功能 &#xff08;1&#xff09;超声波实时测量小车与障碍物间的距离&#xff0c;并用LCD1602显示。 &#xff08;2&#xff09;当测得的距离超过50时&#xff0c;前进电机转动&#xff08;模拟后轮&#…

AIGC工具平台-GPT-SoVITS-v4-TTS音频推理克隆

声音克隆与语音合成的结合&#xff0c;是近年来生成式AI在多模态方向上的重要落地场景之一。随着预训练模型能力的增强&#xff0c;结合语音识别、音素映射与TTS合成的端到端系统成为初学者可以上手实践的全流程方案。 围绕 GPT-SoVITS-v4-TTS 模块&#xff0c;介绍了其在整合…

Android7 Input(十)View 处理Input事件pipeline

概述: 本文主要描述View对InputEvent事件pipeline处理过程。 本文涉及的源码路径 frameworks/base/core/java/android/view/ViewRootImpl.java InputEvent事件处理 View处理input事件是调用doProcessInputEvents方法&#xff0c;如下所示: void doProcessInputEvents() {//…

Neo4j 完全指南:从入门到精通

第1章&#xff1a;Neo4j简介与图数据库基础 1.1 图数据库概述 传统关系型数据库与图数据库的对比图数据库的核心优势图数据库的应用场景 1.2 Neo4j的发展历史 Neo4j的起源与演进Neo4j的版本迭代Neo4j在图数据库领域的地位 1.3 图数据库的基本概念 节点(Node)与关系(Relat…

网心云 OEC/OECT 笔记(1) 拆机刷入Armbian固件

目录 网心云 OEC/OECT 笔记(1) 拆机刷入Armbian固件网心云 OEC/OECT 笔记(2) 运行RKNN程序 外观 内部 PCB正面 PCB背面 PCB背面 RK3566 1Gbps PHY 配置 OEC 和 OECT(OEC-turbo) 都是基于瑞芯微 RK3566/RK3568 的网络盒子, 没有HDMI输入输出. 硬件上 OEC 和 OECT…

摄像机ISP处理流程

1.Bayer&#xff1a;生成raw图&#xff0c;添加色彩数据&#xff08;RGB&#xff09;&#xff0c;一般会将G的占比设置为R和B的和&#xff0c;实例&#xff1a; 2.黑电平矫正&#xff1a;减去暗电流造成的误差&#xff1b; 3.镜头矫正&#xff1a;对四周的亮度进行矫正&#x…

【后端架构师的发展路线】

后端架构师的发展路线是从基础开发到技术领导的系统性进阶过程&#xff0c;需融合技术深度、架构思维和业务洞察力。以下是基于行业实践的职业发展路径和关键能力模型&#xff1a; 一、职业发展阶梯‌ 初级工程师&#xff08;1-3年&#xff09;‌ 核心能力‌&#xff1a;掌…

Unity VR/MR开发-VR开发与传统3D开发的差异

视频讲解链接&#xff1a;【XR马斯维】VR/MR开发与传统3D开发的差异【UnityVR/MR开发教程--入门】_哔哩哔哩_bilibili

RabbitMQ如何保证消息可靠性

RabbitMQ是一个流行的开源消息代理&#xff0c;它提供了可靠的消息传递机制&#xff0c;广泛应用于分布式系统和微服务架构中。在现代应用中&#xff0c;确保消息的可靠性至关重要&#xff0c;以防止消息丢失和重复处理。本文将详细探讨RabbitMQ如何通过多种机制保证消息的可靠…

批量图片管理软件介绍

软件介绍 本文介绍一款功能全面的图片处理软件 - FastStone Image Viewer。 软件功能概述 FastStone Image Viewer不仅支持图片查看&#xff0c;还具备编辑、批量重命名和批量转换等多种实用功能。 用户授权说明 该软件对个人用户完全免费&#xff0c;企业用户只需输入用户…

Playwright 测试框架 - Java

🚀【Playwright + Java 实战教程】从零到一掌握自动化测试利器! 🔧 本文专为 Java 开发者量身打造,通过详尽示例带你快速掌握 Playwright 自动化测试。涵盖基础操作、表单交互、测试框架集成、高阶功能及常见实战技巧,适用于企业 UI 测试与 CI/CD 场景。 🛠️ 一、环境…

nvidia系列教程-Usb otg模式修改为host模式

目录 前言 一、了解 USB OTG 模式与 Host 模式 二、host模式切换 总结 前言 在 NVIDIA 设备的使用过程中,有时我们需要将 USB OTG(On-The-Go)模式切换为 Host 模式,以满足连接外部设备(如 U 盘、鼠标、键盘等)的需求。本文将详细介绍如何在 NVIDIA 设备上进行这一模式…

二叉树-104.二叉树的最大深度-力扣(LeetCode)

一、题目解析 这里需要注意根节点的深度是1&#xff0c;也就是说计算深度的是从1开始计算的 二、算法原理 解法1&#xff1a;广度搜索&#xff0c;使用队列 解法2&#xff1a;深度搜索&#xff0c;使用递归 当计算出左子树的深度l&#xff0c;与右子树的深度r时&#xff0c;…

Calendar类日期设置进位问题

背景 报表需求&#xff0c;需要传递每组数据中最小的日期&#xff0c;后台根据传递的最小日期&#xff0c;向前取参数传递的月份的上个月为结束时间的近五个月数据 例&#xff1a;参数传:2025/02&#xff0c;则需返回2025/01, 2024/12, 2024/11, 2024/10, 2024/09这五个年月数据…

编程笔记---问题小计

编程笔记 qml ProgressBar 为什么valuemodel.progress / 100 在QML中&#xff0c;ProgressBar的value属性用于表示进度条的当前进度值&#xff0c;其范围通常为0到1&#xff08;或0%到100%&#xff09;。当使用model.progress / 100来设置value时&#xff0c;这样做的原因是为…

【STL】函数对象+常用算法

文章目录 STL- 函数对象函数对象函数对象使用 谓词一元谓词二元谓词内建函数对象算术仿函数关系仿函数 STL- 常用算法常用遍历算法for_eachtransform 常用查找算法findfind_ifadjacent_findbinary_searchcountcount_if 常用排序算法sortrandom_shufflemergereverse 常用拷贝和替…

[JVM] JVM内存调优

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏: &#x1f9ca; Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 &#x1f355; Collection与…