python训练营day40

知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 1. 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])# 2. 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# 3. 创建数据加载器
batch_size = 64  # 每批处理64个样本
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义模型、损失函数和优化器
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化模型
model = MLP()
model = model.to(device)  # 将模型移至GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 5. 训练模型(记录每个 iteration 的损失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()  # 设置为训练模式# 新增:记录每个 iteration 的损失all_iter_losses = []  # 存储所有 batch 的损失iter_indices = []     # 存储 iteration 序号(从1开始)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 移至GPU(如果可用)optimizer.zero_grad()  # 梯度清零output = model(data)  # 前向传播loss = criterion(output, target)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数# 记录当前 iteration 的损失(注意:这里直接使用单 batch 损失,而非累加平均)iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)  # iteration 序号从1开始# 统计准确率和损失(原逻辑保留,用于 epoch 级统计)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 每100个批次打印一次训练信息(可选:同时打印单 batch 损失)if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')# 原 epoch 级逻辑(测试、打印 epoch 结果)不变epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')# 绘制所有 iteration 的损失曲线plot_iter_losses(all_iter_losses, iter_indices)# 保留原 epoch 级曲线(可选)# plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, epochs)return epoch_test_acc  # 返回最终测试准确率# 6. 测试模型
def test(model, test_loader, criterion, device):model.eval()  # 设置为评估模式test_loss = 0correct = 0total = 0with torch.no_grad():  # 不计算梯度,节省内存和计算资源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy  # 返回损失和准确率# 7.绘制每个 iteration 的损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 8. 执行训练和测试(设置 epochs=2 验证效果)
epochs = 2  
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/82067.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Baklib企业CMS全流程管控与智能协作

企业CMS全流程管控方案解析 现代企业内容管理中,全流程管控的实现依赖于对生产、审核、发布及迭代环节的系统化整合。通过动态发布引擎与元数据智能标记技术,系统可自动匹配内容与目标场景,实现标准化模板驱动的快速部署。针对多分支机构的复…

Qt程序添加调试输出窗口:CONFIG += console

目录 1.背景 2.解决方案 3.原理详解 4.控制台窗口的行为 5.条件编译(仅调试模式显示控制台) 6.替代方案 7.总结 1.背景 在Qt程序开发中,开发者经常遇到这样的困扰: 开发机上程序运行正常 发布到其他机器后程序无法启动 …

《江西棒球资讯》棒球运动发展·棒球1号位

联赛体系结构 | League Structure MLB模式 MLB采用分层体系(大联盟、小联盟),强调梯队建设和长期发展。 MLB operates a tiered system (Major League, Minor League) with a focus on talent pipelines and long-term development. 中国现…

Python爬虫实战:研究Tornado框架相关技术

1. 引言 1.1 研究背景与意义 网络爬虫作为一种自动获取互联网信息的程序,在信息检索、数据挖掘、舆情分析等领域有着广泛的应用。随着互联网数据量的爆炸式增长,对爬虫的性能和效率提出了更高的要求。传统的同步爬虫在处理大量 URL 时效率低下,而异步爬虫可以显著提高并发…

Vue-列表过滤排序

列表过滤 基础环境 数据 persons: [{ id: "001", name: "刘德华", age: 19 },{ id: "002", name: "马德华", age: 20 },{ id: "003", name: "李小龙", age: 21 },{ id: "004", name: "释小龙&q…

JDK21深度解密 Day 9:响应式编程模型重构

【JDK21深度解密 Day 9】响应式编程模型重构 引言:从Reactor到虚拟线程的范式转变 在JDK21中,虚拟线程的引入彻底改变了传统的异步编程模型。作为"JDK21深度解密"系列的第91天,我们将聚焦于响应式编程模型重构这一关键主题。通过…

UE5打包项目设置Project Settings(打包widows exe安装包)

UE5打包项目Project Settings Edit-Project Settings- Packaging-Ini Section Denylist-Advanced 1:打包 2:高级设置 3:勾选创建压缩包 4:添加要打包地图Map的数量 5:选择要打包的地图Maps 6:Project-Bui…

Fastapi 学习使用

Fastapi 学习使用 Fastapi 可以用来快速搭建 Web 应用来进行接口的搭建。 参考文章:https://blog.csdn.net/liudadaxuexi/article/details/141062582 参考文章:https://blog.csdn.net/jcgeneral/article/details/146505880 参考文章:http…

java helloWord java程序运行机制 用idea创建一个java项目 标识符 关键字 数据类型 字节

HelloWord public class Hello{public static void main(String[] args) {System.out.print("Hello,World!");} }java程序运行机制 用idea创建一个java项目 建立一个空项目 新建一个module 注释 标识符 关键字 标识符注意点 数据类型 public class Demo02 {public st…

随机响应噪声-极大似然估计

一、核心原因:噪声机制的数学可逆性 在随机响应机制(Randomized Response)中使用极大似然估计(Maximum Likelihood Estimation, MLE)是为了从扰动后的噪声数据中无偏地还原原始数据的统计特性。随机响应通过已知概率的…

SMT贴片机工艺优化与效率提升策略

内容概要 现代电子制造领域中,SMT贴片机作为核心生产设备,其工艺优化与效率提升直接影响企业竞争力。本文聚焦设备参数校准、吸嘴选型匹配、SPI检测技术三大技术模块,结合生产流程重构与设备维护周期优化两大管理维度,形成系统性…

AI提示工程(Prompt Engineering)高级技巧详解

AI提示工程(Prompt Engineering)高级技巧详解 文章目录 一、基础设计原则二、高级提示策略三、输出控制技术四、工程化实践五、专业框架应用提示工程是与大型语言模型(LLM)高效交互的关键技术,精心设计的提示可以显著提升模型输出的质量和相关性。以下是经过验证的详细提示工…

光电设计大赛智能车激光对抗方案分享:低成本高效备赛攻略

一、赛题核心难点与备赛痛点解析 全国大学生光电设计竞赛的 “智能车激光对抗” 赛题,要求参赛队伍设计具备激光对抗功能的智能小车,需实现光电避障、目标识别、轨迹规划及激光精准打击等核心功能。从历年参赛情况看,选手普遍面临三大挑战&a…

【KWDB 创作者计划】_再热垃圾发电汽轮机仿真与监控系统:KaiwuDB 批量插入10万条数据性能优化实践

再热垃圾发电汽轮机仿真与监控系统:KaiwuDB 批量插入10万条数据性能优化实践 我是一台N25-3.82/390型汽轮机,心脏在5500转/分的轰鸣中跳动。垃圾焚烧炉是我的胃,将人类遗弃的残渣转化为金色蒸汽,沿管道涌入我的胸腔。 清晨&#x…

day23-计算机网络-1

1. 网络简介 1.1. 网络介质 网线:cat5,cat5e 六类网线,七类网线,芭蕾网线光纤:wifi:无线路由器,ap5G 1.2. 常见网线类型 1.2.1. 双绞线(Twisted Pair Cable)【最常用】 按性能主…

VR/AR 视网膜级显示破局:10000PPI 如何终结颗粒感时代?

一、传统液晶 “纱窗效应”:VR 沉浸体验的最大绊脚石 当用户首次戴上 VR 头显时,眼前密密麻麻的像素网格往往打破沉浸感 —— 这正是传统液晶显示在近眼场景下的致命缺陷。受限于 500-600PPI 的像素密度,即使达到 4K 分辨率,等效到…

2022---不重复版的数的划分-且范围太大

1.数的划分--数的划分--dfs剪枝-CSDN博客 2.范围太大&#xff0c;这题用dp 3.状态转移公式其中1是泛指 #include<bits/stdc.h> using namespace std; #define N 100011 typedef long long ll; typedef pair<int,int> pii; ll dp[2025][12]; int n,k; void solv…

day15 leetcode-hot100-29(链表8)

19. 删除链表的倒数第 N 个结点 - 力扣&#xff08;LeetCode&#xff09; 1.暴力法 思路 &#xff08;1&#xff09;先获取链表的长度L &#xff08;2&#xff09;然后再次遍历链表到L-n的位置&#xff0c;直接让该指针的节点指向下下一个即可。 2.哈希表 思路 &#xff0…

WIN32-内存管理

分配内存-VirtualAlloc 他与malloc和new的不同在于VirtualAlloc是真正意义上的开辟的一片内存 而且它可以为开辟出来的内存指定属性 LPVOID VirtualAlloc([in, optional] LPVOID lpAddress,[in] SIZE_T dwSize,[in] DWORD flAllocationType,[in] …

线程概念与控制

目录 Linux线程概念 什么是线程 分页式存储管理 虚拟地址和页表的由来 物理内存管理 页表 缺页异常 线程的优点 线程的缺点 线程异常 Linux进程VS线程 进程与线程 进程的多个线程共享 进程与线程关系如图 Linux线程控制 POSIX线程库 创建线程 测试 获取线程…