【python深度学习】Day 40 训练和测试的规范写法

知识点回顾:
  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 1. 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])# 2. 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# 3. 创建数据加载器
batch_size = 64  # 每批处理64个样本
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义模型、损失函数和优化器
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化模型
model = MLP()
model = model.to(device)  # 将模型移至GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 5. 训练模型(记录每个 iteration 的损失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()  # 设置为训练模式# 新增:记录每个 iteration 的损失all_iter_losses = []  # 存储所有 batch 的损失iter_indices = []     # 存储 iteration 序号(从1开始)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 移至GPU(如果可用)optimizer.zero_grad()  # 梯度清零output = model(data)  # 前向传播loss = criterion(output, target)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数# 记录当前 iteration 的损失(注意:这里直接使用单 batch 损失,而非累加平均)iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)  # iteration 序号从1开始# 统计准确率和损失(原逻辑保留,用于 epoch 级统计)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 每100个批次打印一次训练信息(可选:同时打印单 batch 损失)if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')# 原 epoch 级逻辑(测试、打印 epoch 结果)不变epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')# 绘制所有 iteration 的损失曲线plot_iter_losses(all_iter_losses, iter_indices)# 保留原 epoch 级曲线(可选)# plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, epochs)return epoch_test_acc  # 返回最终测试准确率# 6. 测试模型
def test(model, test_loader, criterion, device):model.eval()  # 设置为评估模式test_loss = 0correct = 0total = 0with torch.no_grad():  # 不计算梯度,节省内存和计算资源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy  # 返回损失和准确率# 7.绘制每个 iteration 的损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 8. 执行训练和测试(设置 epochs=2 验证效果)
epochs = 2  
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/85115.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

亡羊补牢与持续改进 - SRE 的安全日志、审计与事件响应

亡羊补牢与持续改进 - SRE 的安全日志、审计与事件响应 如果说我们之前讨论的安全措施(如 IAM、网络策略、密钥管理、漏洞补丁)是为我们的“数字城堡”修筑坚固的城墙、设置精密的门锁、定期检查和修补潜在的裂缝,那么安全日志就像是遍布城堡内外的监控摄像头和出入登记簿,…

CppCon 2014 学习第2天:Using Web Services in C++

概述 这是一个会议或演讲的概述内容,主要介绍一个关于C Rest SDK的分享,翻译和理解如下: 翻译 概述 先介绍什么是典型的Web服务结构和它的特征讲讲调用这些Web服务的几种方式重点介绍自己团队开发的一个C库(C Rest SDK&#xf…

【OpenHarmony】【交叉编译】使用gn在Linux编译3568a上运行的可执行程序

linux下编译arm64可执行程序 一.gn ninja安装二.交叉编译工具链安装1.arm交叉编译工具2.安装arm64编译器 三. gn文件添加arm及arm64工具链四.编译验证 本文以gn nijia安装中demo为例,将其编译为在arm64(rk_3568_a开发板)环境下可运行的程序 一.gn ninja安装 安装g…

【开发心得】AstrBot对接飞书失败的问题探究

飞书与AstrBot的集成使用中,偶尔出现连接不稳定的现象。尽管不影响核心功能,但为深入探究技术细节并推动后续优化,需系统性记录该问题。先从底层通信机制入手,分析连接建立的逻辑与数据交互流程。基于实际现象,明确问题发生的具体场景和表现特征,进而梳理潜在影响因素,为…

Spring Boot 3.5.0中文文档上线

Spring Boot 3.5.0 中文文档翻译完成,需要的可收藏 传送门:Spring Boot 3.5.0 中文文档

7.atlas安装

1.服务器规划 软件版本参考: https://cloud.google.com/dataproc/docs/concepts/versioning/dataproc-release-2.2?hlzh-cn 由于hive3.1.3不完全支持jdk8,所以将hive的版本调整成4.0.1。这个版本没有验证过,需要读者自己抉择。 所有的软件都安装再/op…

c# 获取电脑 分辨率 及 DPI 设置

using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Runtime.InteropServices;/// <summary> /// 这个可以 /// </summary> class Program {static void Main(){//设置DPI感知try{SetProcessDpiAwareness(…

LangChain表达式(LCEL)实操案例1

案例1&#xff1a;写一篇短文&#xff0c;然后对这篇短文进行打分 from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableWithMessageHist…

OleDbParameter.Value 与 DataTable.Rows.Item.Value 的性能对比

OleDbParameter.Value 与 DataTable.Rows.Item.Value 的性能对比 您提到的两种赋值操作属于不同场景&#xff0c;它们的性能和稳定性取决于具体使用方式。下面从几个维度进行分析&#xff1a; 1. 操作本质对比 &#xff08;1&#xff09;OleDbParameter.Value 用途&#xf…

【Opencv+Yolo】Day2_图像处理

目录 一、图像梯度计算 图像梯度-sobal算子&#xff1a; Scharr&#xff1a;权重变化更大&#xff08;线条更加丰富&#xff0c;比Sobel更加细致捕捉更多梯度信息&#xff09; Laplacian算子&#xff1a;对噪音点敏感&#xff08;可以和其他一起结合使用&#xff09; 二、边…

STM32通过rt_hw_hard_fault_exception中的LR寄存器追溯程序问题​

1. 问题现象 程序运行导致rt_hw_hard_fault_exception 如图 显示错误相关代码 struct exception_stack_frame {uint32_t r0;uint32_t r1;uint32_t r2;uint32_t r3;uint32_t r12; uint32_t lr; // 链接寄存器 (LR)uint32_t pc; // 程序计数器 (PC)uint32_t psr; // 程序状态…

Mac安装配置InfluxDB,InfluxDB快速入门,Java集成InfluxDB

1. 与MySQL的比较 InfluxDBMySQL解释BucketDatabase数据库MeasurementTable表TagIndexed Column索引列FieldColumn普通列PointRow每行数据 2. 安装FluxDB brew update默认安装 2.x的版本 brew install influxdb查看influxdb版本 influxd version # InfluxDB 2.7.11 (git: …

【spring】spring中的retry重试机制; resilience4j熔断限流教程;springboot整合retry+resilience4j教程

在调用三方接口时&#xff0c;我们一般要考虑接口调用失败的处理&#xff0c;可以通过spring提供的retry来实现&#xff1b;如果重试几次都失败了&#xff0c;可能就要考虑降级补偿了&#xff1b; 有时我们也可能要考虑熔断&#xff0c;在微服务中可能会使用sentinel来做熔断&a…

(21)量子计算对密码学的影响

文章目录 2️⃣1️⃣ 量子计算对密码学的影响 &#x1f30c;&#x1f50d; TL;DR&#x1f680; 量子计算&#xff1a;密码学的终结者&#xff1f;⚡ 量子计算的破坏力 &#x1f510; Java密码学体系面临的量子威胁&#x1f525; 受影响最严重的Java安全组件 &#x1f6e1;️ 后…

经营分析会,财务该怎么做?

目录 一、业绩洞察&#xff1a;从「现象描述」到「因果分析」 1.分层拆解 2.关联验证 3.根因追溯 二、预算管理&#xff1a;从「刚性控制」到「动态平衡」 1.分类管控 2.滚动校准 3.价值评估 三、客户与市场&#xff1a;从「交易记录」到「价值评估」 1.价值分层 2.…

进阶智能体实战九、图文需求分析助手(ChatGpt多模态版)(帮你生成 模块划分+页面+表设计、状态机、工作流、ER模型)

🧠 基于 ChatGPT 多模态大模型的需求文档分析助手 本文将介绍如何利用 OpenAI 的 GPT-4o 多模态能力,构建一个智能的需求文档分析助手,自动提取功能模块、菜单设计、字段设计、状态机、流程图和 ER 模型等关键内容。 一、🔧 环境准备 在开始之前,请确保您已经完成了基础…

图书管理系统的设计与实现

湖南软件职业技术大学 本科毕业设计(论文) 设计(论文)题目 图书管理系统的设计与实现 学生姓名 学生学号 所在学院 专业班级 毕业设计(论文)真实性承诺及声明 学生对毕业设计(论文)真实性承诺 本人郑重声明:所提交的毕业设计(论文)作品是本人在指导教师的指导下,独…

直线模组在手术机器人中有哪些技术挑战?

手术机器人在现代医疗领域发挥着越来越重要的作用&#xff0c;直线模组作为其关键部件&#xff0c;对手术机器人的性能有着至关重要的影响。然而&#xff0c;在手术机器人中使用直线模组面临着诸多技术挑战&#xff0c;具体如下&#xff1a; 1、‌高精度要求‌&#xff1a;手术…

技术-工程-管用养修保-智能硬件-智能软件五维黄金序位模型

融智学工程技术体系&#xff1a;五维协同架构 基于邹晓辉教授的框架&#xff0c;工程技术体系重构为&#xff1a;技术-工程-管用养修保-智能硬件-智能软件五维黄金序位模型&#xff1a; math \mathbb{E}_{\text{技}} \underbrace{\prod_{\text{Dis}} \text{TechnoCore}}_{\…

InnoDB引擎逻辑存储结构及架构

简化理解版 想象 InnoDB 是一个高效运转的仓库&#xff1a; 核心内存区 (大脑 & 高速缓存 - 干活超快的地方) 缓冲池 Buffer Pool (最最核心&#xff01;)&#xff1a; 作用&#xff1a; 相当于仓库的“高频货架”。把最常用的数据&#xff08;表数据、索引&#xff09;从…