如何保存训练的最优模型和使用最优模型文件

一 保存最优模型

主要就是我们在for循环中加上一个test测试,并且我还在test函数后面加上了返回值,可以返回准确率,然后每次进行一次对比,然后取大的。然后这里有两种保存方式,一种是保存了整个模型,另一个是保存了模型参数。

1 仅保存模型参数

torch.save(model.state_dict(),'best_model.pth')

然后后面我们使用的时候

model =torch.load('best1.pth')#
model.to(device)
model.load_state_dict(torch.load("best.pth", map_location=device))
model.eval()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
test(test_dataloader,model,loss_fn)

注意这里要设置eval模式,因为我们要保证我们的模型参数不再变化了。

2 保存整个模型

torch.save(model,'best1.pth')

在调用的时候

model = torch.load('best1.pth', map_location=torch.device('cuda'))
model.eval()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
test(test_dataloader,model,loss_fn)

直接调用就好。

注意这两种必须要有定义好的网络,不然无法运行(保存整个网络也要定于一个完全相同的网络)。

完整代码

epochs=20
for i in range(epochs):print(f"Epoch {i+1}")train(train_dataloader,model,loss_fn,optimizer)corrects = test(test_dataloader,model,loss_fn)accuracy_list.append(corrects)if corrects>best_acc:print(f"Best Accuracy: {corrects}%")best_acc=corrects#第一种# torch.save(model.state_dict(),'best_model.pth')#第二种torch.save(model,'best1.pth')

完整代码含网络

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transformsclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(  # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=3,  # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数),out_channels=16,  # 要得到多少个特征图,卷积核的个数kernel_size=5,  # 卷积核人小,5*5stride=1,  # 步长padding=2  # 填充值),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),)self.out = nn.Linear(64 * 64 * 64, 20)  # 全连接层得到的结果def forward(self, x):  # 前向传播,你得告诉它 数据的流向 是神经网络层连接起来,函数名称不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # view和reshape是一样的作用,但此处是tensor形式output = self.out(x)return outputdata_transform={# 'train': transforms.Compose([#     # 调整图像大小为300x300像素#     transforms.Resize([256, 256]),##     # # 随机旋转:-45到45度之间随机选择角度#     # transforms.RandomRotation(45),#     # ##     # # # 从中心裁剪出256x256的区域#     # transforms.CenterCrop([256, 256]),#     ##     # # 随机水平翻转:以50%的概率进行水平镜像#     # transforms.RandomHorizontalFlip(p=0.5),#     ##     # # 随机垂直翻转:以50%的概率进行垂直镜像#     # transforms.RandomVerticalFlip(p=0.5),#     ##     # # # 颜色抖动:随机调整亮度、对比度、饱和度和色调#     # # transforms.ColorJitter(#     # #     brightness=0.2,    # 亮度变化幅度为20%#     # #     contrast=0.1,      # 对比度变化幅度为10%#     # #     saturation=0.1,    # 饱和度变化幅度为10%#     # #     hue=0.1            # 色调变化幅度为10%#     # # ),#     # ##     # # # 随机灰度化:以10%的概率将图像转换为灰度图#     # transforms.RandomGrayscale(p=0.1),##     # 将PIL图像转换为PyTorch张量,并自动归一化到[0,1]范围#     transforms.ToTensor(),##     # 标准化:使用ImageNet数据集的均值和标准差进行标准化#     transforms.Normalize(#         [0.485, 0.456, 0.406],  # 均值(R, G, B通道)#         [0.229, 0.224, 0.225]   # 标准差(R, G, B通道)#     )# ]),# 验证/测试数据的预处理(通常不需要数据增强)'test': transforms.Compose([transforms.Resize([256, 256]),# transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}class food_dataset(Dataset):def __init__(self, root, transform=None):super().__init__()self.root = rootself.transform = transformself.images = []self.labels = []with open(root,encoding='utf-8') as f:samples = [i.strip().split() for i in f.readlines()]for img_path,label in samples:self.images.append(img_path)self.labels.append(label)def __len__(self):return len(self.images)def __getitem__(self, index):image=Image.open(self.images[index]).convert('RGB')if self.transform:image=self.transform(image)label = self.labels[index]# print(label)label = torch.from_numpy(np.array(label,dtype=np.int64))# print(label)return image, labeldef test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()batch_size_num=1loss,correct=0,0with torch.no_grad():for X, y in test_dataloader:X,y=X.to(device),y.to(device)pred = model(X)loss = loss_fn(pred,y)+losscorrect += (pred.argmax(1) == y).type(torch.float).sum().item()loss/=num_batchescorrect/=sizeprint(f'Test result: \n Accuracy: {(100*correct)}%,Avg loss: {loss}')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")test_data=food_dataset('test_data',transform=(data_transform['test']))
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)# model =CNN()
# model.to(device)
# model.load_state_dict(torch.load("best.pth"))
model=torch.load('best.pt')
model.eval()
loss_fn = nn.CrossEntropyLoss()
test(test_dataloader,model,loss_fn)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/95704.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/95704.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3+ts+echarts多Y轴折线图

因为放在了子组件才监听&#xff0c;加载渲染调用&#xff0c;有暗黑模式才调用&#xff0c;<!-- 温湿度传感器 --><el-row v-if"deviceTypeId 2"><el-col :xs"24" :sm"24" :md"24" :lg"24" :xl"24&qu…

基于Taro4打造的一款最新版微信小程序、H5的多端开发简单模板

基于Taro4、Vue3、TypeScript、Webpack5打造的一款最新版微信小程序、H5的多端开发简单模板 特色 &#x1f6e0;️ Taro4, Vue 3, Webpack5, pnpm10 &#x1f4aa; TypeScript 全新类型系统支持 &#x1f34d; 使用 Pinia 的状态管理 &#x1f3a8; Tailwindcss4 - 目前最流…

ITU-R P.372 无线电噪声预测库调用方法

代码功能概述&#xff08;ITURNoise.c&#xff09;该代码是一个 ITU-R P.372 无线电噪声预测 的计算程序&#xff0c;能够基于 月份、时间、频率、地理位置、人为噪声水平 计算特定地点的 大气噪声、银河噪声、人为噪声及其总和&#xff0c;并以 CSV 或标准输出 方式提供结果。…

《从报错到运行:STM32G4 工程在 Keil 中的头文件配置与调试实战》

《从报错到运行&#xff1a;STM32G4 工程在 Keil 中的头文件配置与调试实战》文章提纲一、引言• 阐述 STM32G4 在嵌入式领域的应用价值&#xff0c;说明 Keil 是开发 STM32G4 工程的常用工具• 指出头文件配置是 STM32G4 工程在 Keil 中开发的关键基础环节&#xff0c;且…

Spring 事务提交成功后执行额外逻辑

1. 场景与要解决的问题在业务代码里&#xff0c;常见诉求是&#xff1a;只有当数据库事务真正提交成功后&#xff0c;才去执行某些“后置动作”&#xff0c;例如&#xff1a;发送 MQ、推送消息、写审计/埋点日志、刷新缓存、通知外部系统等。如果这些动作在事务提交前就执行&am…

Clickhouse MCP@Mac+Cherry Studio部署与调试

一、需求背景 已经部署测试了Mysql、Drois的MCP Server,想进一步测试Clickhouse MCP的表现。 二、环境 1)操作系统 MacOS+Apple芯片 2)Clickhouse v25.7.6.21-stable、Clickhouse MCP 0.1.11 3)工具Cherry Studio 1.5.7、Docker Desktop 4.43.2(199162) 4)Python 3.1…

Java Serializable 接口:明明就一个空的接口嘛

对于 Java 的序列化,我之前一直停留在最浅层次的认知上——把那个要序列化的类实现 Serializbale 接口就可以了嘛。 我似乎不愿意做更深入的研究,因为会用就行了嘛。 但随着时间的推移,见到 Serializbale 的次数越来越多,我便对它产生了浓厚的兴趣。是时候花点时间研究研…

野火STM32Modbus主机读取寄存器/线圈失败(三)-尝试将存贮事件的地方改成数组(非必要解决方案)(附源码)

背景 尽管crc校验正确了&#xff0c;也成功发送了EV_MASTER_EXECUTE事件&#xff0c;但是eMBMasterPoll( void )中总是接收的事件是EV_MASTER_FRAME_RECEIVED或者EV_MASTER_FRAME_SENT&#xff0c;一次都没有执行EV_MASTER_EXECUTE。EV_MASTER_EXECUTE事件被别的事件给覆盖了&…

微信小程序校园助手程序(源码+文档)

源码题目&#xff1a;微信小程序校园助手程序&#xff08;源码文档&#xff09;☑️ 文末联系获取&#xff08;含源码、技术文档&#xff09;博主简介&#xff1a;10年高级软件工程师、JAVA技术指导员、Python讲师、文章撰写修改专家、Springboot高级&#xff0c;欢迎高校老师、…

59-python中的类和对象、构造方法

1. 认识一下对象 世间万物皆是"对象" student_1{ "姓名":"小朴", "爱好":"唱、跳、主持" ......... }白纸填写太落伍了 设计表格填写先进一些些 终极目标是程序使用对象去组织数据程序中设计表格&#xff0c;我们称为 设计类…

向成电子惊艳亮相2025物联网展,携工控主板等系列产品引领智造新风向

2025年8月27-29日&#xff0c;IOTE 2025 第二十四届国际物联网展深圳站在深圳国际会展中心&#xff08;宝安&#xff09;盛大启幕&#xff01;作为全球规模领先的物联网盛会之一&#xff0c;本届展会以“生态智能&#xff0c;物联全球”为核心&#xff0c;汇聚超1000家全球头部…

阵列信号处理之均匀面阵波束合成方向图的绘制与特点解读

阵列信号处理之均匀面阵波束合成方向图的绘制与特点解读 文章目录前言一、方向图函数二、方向图绘制三、副瓣电平四、阵元个数对主瓣宽度的影响五、阵元间距对主瓣宽度的影响六、MATLAB源代码总结前言 \;\;\;\;\;均匀面阵&#xff08;Uniform Planar Array&#xff0c;UPA&…

算法在前端框架中的集成

引言 算法是前端开发中提升性能和用户体验的重要工具。随着 Web 应用复杂性的增加&#xff0c;现代前端框架如 React、Vue 和 Angular 提供了强大的工具集&#xff0c;使得将算法与框架特性&#xff08;如状态管理、虚拟 DOM 和组件化&#xff09;无缝集成成为可能。从排序算法…

网络爬虫是自动从互联网上采集数据的程序

网络爬虫是自动从互联网上采集数据的程序网络爬虫是自动从互联网上采集数据的程序&#xff0c;Python凭借其丰富的库生态系统和简洁语法&#xff0c;成为了爬虫开发的首选语言。本文将全面介绍如何使用Python构建高效、合规的网络爬虫。一、爬虫基础与工作原理 网络爬虫本质上是…

Qt Model/View/Delegate 架构详解

Qt Model/View/Delegate 架构详解 Qt的Model/View/Delegate架构是Qt框架中一个重要的设计模式&#xff0c;它实现了数据存储、数据显示和数据编辑的分离。这种架构不仅提高了代码的可维护性和可重用性&#xff0c;还提供了极大的灵活性。 1. 架构概述 Model/View/Delegate架构将…

光谱相机在手机行业的应用

在手机行业&#xff0c;光谱相机技术通过提升拍照色彩表现和扩展健康监测等功能&#xff0c;正推动摄像头产业链升级&#xff0c;并有望在AR/VR、生物医疗等领域实现更广泛应用。以下为具体应用场景及技术突破的详细说明&#xff1a;‌一、光谱相机在手机行业的应用场景‌‌拍照…

FASTMCP中的Resources和Templates

Resources 给 MCP 客户端/LLM 读取的数据端点&#xff08;只读、按 URI 索引、像“虚拟文件系统”或“HTTP GET”&#xff09;&#xff1b; Templates 可带参数的资源路由&#xff08;URI 里占位符 → 运行函数动态生成内容&#xff09;。 快速要点 • 用途&#xff1a;把文件…

OpenBMC之编译加速篇

加快 OpenBMC 的编译速度是一个非常重要的话题,因为完整的构建通常非常耗时(在高性能机器上也需要数十分钟,普通电脑上可能长达数小时)。以下是从不同层面优化编译速度的详细策略,您可以根据自身情况组合使用。 一、核心方法:利用 BitBake 的缓存和共享机制(效果最显著…

Kafka面试精讲 Day 8:日志清理与数据保留策略

【Kafka面试精讲 Day 8】日志清理与数据保留策略 在Kafka的高吞吐、持久化消息系统中&#xff0c;日志清理与数据保留策略是决定系统资源利用效率、数据可用性与合规性的关键机制。作为“Kafka面试精讲”系列的第8天&#xff0c;本文聚焦于日志清理机制&#xff08;Log Cleani…

基于Hadoop的网约车公司数据分析系统设计(代码+数据库+LW)

摘 要 本系统基于Hadoop平台&#xff0c;旨在为网约车公司提供一个高效的数据分析解决方案。随着网约车行业的快速发展&#xff0c;平台上产生的数据量日益增加&#xff0c;传统的数据处理方式已无法满足需求。因此&#xff0c;设计了一种基于Hadoop的大规模数据处理和分析方…