Python打卡第52天

@浙大疏锦行

作业:

对于day'41的简单cnn,看看是否可以借助调参指南进一步提高精度。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 1. 改进数据预处理 - 增加数据增强
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),  # 随机裁剪transforms.RandomHorizontalFlip(),     # 随机水平翻转transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # 使用CIFAR-10的真实统计数据
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=transform_train
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=transform_test
)# 3. 增加Batch Size - 从64增加到128
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)# 4. 改进模型结构 - 更深、更宽的网络
class ImprovedMLP(nn.Module):def __init__(self):super(ImprovedMLP, self).__init__()self.flatten = nn.Flatten()# 增加网络宽度和深度self.layer1 = nn.Linear(3072, 1024)self.bn1 = nn.BatchNorm1d(1024)  # 添加批归一化self.relu1 = nn.ReLU()self.dropout1 = nn.Dropout(0.3)self.layer2 = nn.Linear(1024, 1024)self.bn2 = nn.BatchNorm1d(1024)self.relu2 = nn.ReLU()self.dropout2 = nn.Dropout(0.3)self.layer3 = nn.Linear(1024, 512)self.bn3 = nn.BatchNorm1d(512)self.relu3 = nn.ReLU()self.dropout3 = nn.Dropout(0.3)self.layer4 = nn.Linear(512, 256)self.bn4 = nn.BatchNorm1d(256)self.relu4 = nn.ReLU()self.dropout4 = nn.Dropout(0.3)self.layer5 = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x)x = self.layer1(x)x = self.bn1(x)x = self.relu1(x)x = self.dropout1(x)x = self.layer2(x)x = self.bn2(x)x = self.relu2(x)x = self.dropout2(x)x = self.layer3(x)x = self.bn3(x)x = self.relu3(x)x = self.dropout3(x)x = self.layer4(x)x = self.bn4(x)x = self.relu4(x)x = self.dropout4(x)x = self.layer5(x)return x# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 初始化模型
model = ImprovedMLP()
model = model.to(device)# 5. 优化器与学习率调度 - 使用学习率预热和余弦退火
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)  # 添加L2正则化
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)  # 余弦退火调度器# 6. 早停机制
class EarlyStopping:def __init__(self, patience=10, delta=0):self.patience = patienceself.delta = deltaself.counter = 0self.best_score = Noneself.early_stop = Falsedef __call__(self, val_acc):score = val_accif self.best_score is None:self.best_score = scoreelif score < self.best_score + self.delta:self.counter += 1if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.counter = 0return self.early_stop# 7. 训练模型(改进版,记录训练和验证准确率)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()# 记录每个 epoch 的准确率和损失train_acc_history = []train_loss_history = []test_acc_history = []test_loss_history = []# 早停实例early_stopping = EarlyStopping(patience=15)for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 每100个批次打印一次训练信息if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 损失: {loss.item():.4f} | 准确率: {100.*correct/total:.2f}%')# 计算当前epoch的平均训练损失和准确率epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_loss_history.append(epoch_train_loss)train_acc_history.append(epoch_train_acc)# 测试阶段model.eval()test_loss = 0correct_test = 0total_test = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testtest_loss_history.append(epoch_test_loss)test_acc_history.append(epoch_test_acc)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')# 更新学习率scheduler.step()# 检查早停if early_stopping(epoch_test_acc):print(f"早停触发!在 epoch {epoch+1} 停止训练")break# 绘制训练和测试准确率曲线plot_accuracy(train_acc_history, test_acc_history, epochs)# 绘制训练和测试损失曲线plot_loss(train_loss_history, test_loss_history, epochs)return epoch_test_acc, epoch_test_loss# 8. 绘制准确率曲线
def plot_accuracy(train_acc, test_acc, epochs):plt.figure(figsize=(10, 5))plt.plot(range(1, len(train_acc)+1), train_acc, 'b-', label='训练准确率')plt.plot(range(1, len(test_acc)+1), test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('训练和测试准确率')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 9. 绘制损失曲线
def plot_loss(train_loss, test_loss, epochs):plt.figure(figsize=(10, 5))plt.plot(range(1, len(train_loss)+1), train_loss, 'b-', label='训练损失')plt.plot(range(1, len(test_loss)+1), test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失')plt.title('训练和测试损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 10. 执行训练和测试
epochs = 200  # 增加训练轮次
print("开始训练模型...")
final_accuracy, final_loss = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}% | 最终测试损失: {final_loss:.4f}")# 保存模型
torch.save(model.state_dict(), 'cifar10_improved_mlp_model.pth')
print("模型已保存为: cifar10_improved_mlp_model.pth")

训练完成!最终测试准确率: 93.98%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/83465.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣100- 环形链表

方法一 遍历 循环链表&#xff0c;查找链表节点是否重复出现 public boolean hasCycle(ListNode head) {Set<ListNode> set new HashSet<>(); if (head null) return false; while (head ! null) {if (set.contains(head)) {return true;}set.add(head);head …

Java + Spring Boot + Mybatis 插入数据后,获取自增 id 的方法

在 MyBatis 中使用 useGeneratedKeys"true" 获取新插入记录的自增 ID 值&#xff0c;可通过以下步骤实现&#xff1a; 1. 配置 Mapper XML 在插入语句的 <insert> 标签中设置&#xff1a; xml 复制 下载 运行 <insert id"insertUser" para…

Meta发布V-JEPA 2世界模型及物理推理新基准,推动AI在物理世界中的认知与规划能力

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

触觉智能RK3576核心板工业应用之软硬件全国产化,成功适配开源鸿蒙OpenHarmony5.0

在全球科技竞争加剧和供应链安全日益重要的背景下&#xff0c;实现关键软硬件的全国产化替代已成为国家战略和产业共识。在这一背景下&#xff0c;触觉智能推出RK3576核心板,率先适配开源鸿蒙OpenHarmony5.0操作系统&#xff0c;真正实现了从芯片到操作系统的全栈国产化方案&am…

前端基础知识ES6系列 - 01(var、let、const之间的区别)

一、var 在ES5中&#xff0c;顶层对象的属性和全局变量是等价的&#xff0c;用var声明的变量既是全局变量&#xff0c;也是顶层变量 注意&#xff1a;顶层对象&#xff0c;在浏览器环境指的是window对象&#xff0c;在 Node 指的是global对象 var a 10; console.log(window…

Python Docker 镜像构建完整指南:从基础到优化

Python 是一门广泛使用的编程语言,在容器化环境中,构建和使用 Python 镜像是非常常见的任务。本文将提供一个完整的指南,包括选择基础镜像、制作流程、不同场景下的应用、安全性最佳实践以及镜像优化策略。 1. 选择合适的基础镜像 1.1 官方 Python 镜像 Docker Hub 提供了…

【狂飙AGI】第1课:大模型概述

目录 &#xff08;一&#xff09;大模型概念解析&#xff08;二&#xff09;大模型发展历程&#xff08;三&#xff09;大模型发展现状&#xff08;1&#xff09;OpenAI&#xff08;2&#xff09;微软&#xff08;3&#xff09;谷歌&#xff08;4&#xff09;Meta &#xff08;…

vite ts 配置使用@ 允许js

1.vite.config.ts 配置 import { defineConfig } from vite import vue from vitejs/plugin-vue import { fileURLToPath, URL } from node:url import setup_extend from vite-plugin-vue-setup-extend// https://vite.dev/config/ export default defineConfig({plugins: …

使用Ollama+open-webui搭建本地AI模型

本地搭建AI模型 说明&#xff1a;1、下载Ollama2、下载模型3、pip安装open-webui&#xff08;不推荐&#xff09;1、Python版本不对应2、下载wheels失败 4、docker安装open-webui 说明&#xff1a; 在windows上搭建本地AI&#xff0c;使用Ollamaopen-webui的方式&#xff0c;可…

第 87 场周赛:比较含退格的字符串、数组中的最长山脉、一手顺子、访问所有节点的最短路径

Q1、[简单] 比较含退格的字符串 1、题目描述 给定 s 和 t 两个字符串&#xff0c;当它们分别被输入到空白的文本编辑器后&#xff0c;如果两者相等&#xff0c;返回 true 。# 代表退格字符。 **注意&#xff1a;**如果对空文本输入退格字符&#xff0c;文本继续为空。 示例 …

linux安装阿里DataX实现数据迁移

目录 下载datax工具包(如果下载慢&#xff0c;请尝试其他国内镜像站或其他网站下载相应资源) 解压工具包到当前目录里 接着进入conf配置目录并创建一个myjob.json&#xff08;临时测试json&#xff09;&#xff0c;myjob.json内容如下&#xff0c;用于模拟test库tab1表数据同…

C++ 引用介绍

很好&#xff01;既然你有 C 的基础&#xff0c;那么理解 C 的「引用&#xff08;reference&#xff09;」会容易很多。我们来一步步讲清楚这个概念。 &#x1f31f; 一句话总结&#xff1a; C 引用&#xff08;reference&#xff09;就是已存在变量的“别名”&#xff0c;它不…

学习笔记086——@PostConstruct注解和InitializingBean接口的使用

文章目录 1、PostConstruct注解1.1 介绍1.2 用法1.3 场景 2、InitializingBean接口2.1 介绍2.2 用法 1、PostConstruct注解 1.1 介绍 PostConstruct 是 Java EE/Jakarta EE 中的一个注解&#xff0c;用于标记一个方法在依赖注入完成后执行初始化操作。它通常与 Spring 框架一…

考研系列—408真题操作系统篇(2015-2019)

目录 # 2015年 1.死锁处理 (1)预防死锁 (2)避免死锁 (3)死锁检测和解除 2.请求分页系统的页面置换策略、页面置换策略 3.页、页框、页表,基本分页系统 # 2016年 1.异常、中断 2.页置换算法 3.进程的互斥操作 4.SPOOLing技术(从软件方面实现设备共享) 5.一定要牢记…

argocd部署cli工具并添加k8s集群

先决条件: 1.已经有k8s集群,(网上一万种部署方式,这里我使用的是kubekey部署的),也埋了个坑,后面说明. 2.已经部署好argocd,并验证web已经可以访问.参见 k8s部署argocd-CSDN博客 部署客户端工具, 这里我是从web页面上直接下载的对应版本的cli工具. 打开已经部署好的argoc…

打卡day52

简单cnn 借助调参指南进一步提高精度 基础CNN模型代码 import tensorflow as tf from tensorflow.keras import layers, models from tensorflow.keras.datasets import cifar10 from tensorflow.keras.utils import to_categorical# 加载数据 (train_images, train_labels),…

OpenGL ES绘制3D图形以及设置视口

文章目录 关于 glDrawElements基本概念使用场景mode 绘制模式type 索引数据类型indices 索引缓冲区工作原理绘制正方体实例 视口透视投影&#xff08;Perspective Projection&#xff09;正交投影&#xff08;Orthographic Projection&#xff09;正交投影和透视投影对比 关于 …

【SAS求解多元回归方程】REG多元回归分析-多元一次回归

多元一次回归是一种统计方法&#xff0c;用于分析多个自变量&#xff08;解释变量&#xff09;与一个因变量&#xff08;响应变量&#xff09;之间的线性关系。 目录 【示例】 基本语法 SAS代码 参数估计 方差分析 回归统计量 y的拟合诊断 y的回归变量值 【示例】 设Y…

卡通幼儿园教育通用可爱PPT模版分享

幼儿园教育通用PPT模版&#xff0c;教育教学PPT模版&#xff0c;卡通教育PPT模版&#xff0c;可爱卡通教学课件PPT模版&#xff0c;小清新动物卡通通用PPT模版&#xff0c;教学说课通用PPT模版&#xff0c;开学季PPT模版&#xff0c;国学颂歌PPT模版&#xff0c;可爱简约风PPT模…

力扣HOT100之技巧:75. 颜色分类

这道题实际上就是让我们不用sort()函数来实现对原数组的排序&#xff0c;这里我直接使用快速排序对原数组进行排序了&#xff0c;也是复习一下基于快慢指针的快速排序写法。面试手撕快排的思路参考这个视频。 用时击败100%&#xff0c;还行。下面直接贴代码。 class Solution …