Torch -- 卷积学习day4 -- 完整项目流程

完整项目流程总结

1. 环境准备与依赖导入

import time
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
import wandb
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import *
import matplotlib.pyplot as plt

2. 数据准备与增强

# 数据增强变换
transform = transforms.Compose([transforms.RandomRotation(45),transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
​
# 测试集变换
transformtest = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])
​
# 数据集加载
train_dataset = CIFAR10(root=datapath,train=True,download=True,transform=transform,
)
​
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,
)

3. 模型构建与初始化

# 获取ResNet18模型并调整全连接层
model = resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=10)
​
# 加载预训练权重(如果有)
if os.path.exists(weightpath):weights_default = torch.load(weightpath)weights_default.pop("fc.weight", None)weights_default.pop("fc.bias", None)new_state_dict = model.state_dict()weights_default_process = {k: v for k, v in weights_default.items() if k in new_state_dict}new_state_dict.update(weights_default_process)model.load_state_dict(new_state_dict)
​
model.to(device)

4. 训练过程

# 初始化训练工具
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
​
# 可视化工具初始化
wandb.init(project="my-qianyi-project", config={...})
write1 = SummaryWriter(log_dir=log_dir)
write1.add_graph(model, input_to_model=torch.randn(1, 3, 32, 32).to(device))
​
# 训练循环
for epoch in range(epochs):model.train()# 训练代码...torch.save(model.state_dict(), weightpath)

5. 验证与评估

# 加载最佳模型进行验证
model.load_state_dict(torch.load(weightpath))
model.eval()
​
# 验证过程
# 保存预测结果到CSV
# 生成分类报告和混淆矩阵

6. 模型应

# 加载模型进行推理
def predict_image(image_path):# 图像预处理# 模型预测# 返回结果

7. 模型移植与部署

7.1 模型转换(PyTorch → ONNX/)

python

# 转换为ONNX格式
def convert_to_onnx(model, input_size, onnx_path):model.eval()dummy_input = torch.randn(1, *input_size).to(device)torch.onnx.export(model,dummy_input,onnx_path,export_params=True,opset_version=11,do_constant_folding=True,input_names=['input'],output_names=['output'],dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})print(f"Model converted to ONNX and saved to {onnx_path}")
​
# 使用示例
convert_to_onnx(model, (3, 32, 32), "model.onnx")
7.2 模型量化(减小模型大小,加速推理)

python

# 动态量化
def quantize_model(model):quantized_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)return quantized_model
​
# 使用示例
quantized_model = quantize_model(model)
torch.save(quantized_model.state_dict(), "quantized_model.pth")
7.3 减少参数数量
# 简单的权重剪枝
def prune_model(model, pruning_percentage=0.2):parameters_to_prune = []for name, module in model.named_modules():if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):parameters_to_prune.append((module, 'weight'))torch.nn.utils.prune.global_unstructured(parameters_to_prune,pruning_method=torch.nn.utils.prune.L1Unstructured,amount=pruning_percentage,)return model
​
# 使用示例
pruned_model = prune_model(model)
7.4 移动端部署(使用ONNX Runtime)
# 保存为LibTorch格式(C++可用)
example = torch.rand(1, 3, 32, 32).to(device)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
7.5 Web部署(使用ONNX.js)
# 首先转换为ONNX,然后使用ONNX.js在浏览器中运行
# 或者使用第三方工具如https://github.com/onnx/tensorflow-onnx
7.6 边缘设备部署(使用TensorRT、OpenVINO等)
# 使用NVIDIA TensorRT优化(需要先转换为ONNX)
# 或使用Intel OpenVINO工具包

8. 性能监控与优化

# 模型推理速度测试
def benchmark_model(model, input_size, num_runs=100):model.eval()input_tensor = torch.randn(1, *input_size).to(device)# GPU预热for _ in range(10):_ = model(input_tensor)# 计时start_time = time.time()for _ in range(num_runs):_ = model(input_tensor)end_time = time.time()avg_time = (end_time - start_time) / num_runsfps = 1 / avg_timeprint(f"Average inference time: {avg_time*1000:.2f} ms, FPS: {fps:.2f}")return avg_time, fps
​
# 使用示例
benchmark_model(model, (3, 32, 32))

这个完整的流程涵盖了从数据准备到模型部署的全过程,特别是新增的模型移植部分,提供了将训练好的模型部署到不同平台和设备的方法,这对于实际应用非常重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/93926.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/93926.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MTK Linux DRM分析(七)- KMS drm_plane.c

一、简介在 Linux DRM(Direct Rendering Manager)子系统中,Plane(平面)代表了一个图像源,可以在扫描输出过程中与 CRTC 混合或叠加显示。每个 Plane 从 drm_framebuffer 中获取输入数据,并负责图…

OpenHarmony之 蓝牙子系统全栈剖析:从协议栈到芯片适配的端到端实践(大合集)

1. 系统架构概述 OpenHarmony蓝牙系统采用分层架构设计,基于HDF(Hardware Driver Foundation)驱动框架和系统能力管理(System Ability)机制实现。 1.1 架构层次 ┌─────────────────────────…

探索 Ultralytics YOLOv8标记图片

1、下载YOLOv8模型文件 下载地址:https://docs.ultralytics.com/zh/models/yolov8/#performance-metrics 2、编写python脚本 aaa.py import cv2 import numpy as np from ultralytics import YOLO import matplotlib.pyplot as pltdef plot_detection(image, box…

Matplotlib数据可视化实战:Matplotlib子图布局与管理入门

Matplotlib多子图布局实战 学习目标 通过本课程的学习,学员将掌握如何在Matplotlib中创建和管理多个子图,了解子图布局的基本原理和调整方法,能够有效地展示多个数据集,提升数据可视化的效果。 相关知识点 Matplotlib子图 学习内容…

【python实用小脚本-194】Python一键给PDF加水印:输入文字秒出防伪文件——再也不用开Photoshop

Python一键给PDF加水印:输入文字秒出防伪文件——再也不用开Photoshop PDF加水印, 本地脚本, 零会员费, 防伪标记, 瑞士军刀 故事开场:一把瑞士军刀救了投标的你 周五下午,你把 100 页标书 PDF 发给客户,却担心被同行盗用。 想加水…

开源 C++ QT Widget 开发(四)文件--二进制文件查看编辑

文章的目的为了记录使用C 进行QT Widget 开发学习的经历。临时学习,完成app的开发。开发流程和要点有些记忆模糊,赶紧记录,防止忘记。 相关链接: 开源 C QT Widget 开发(一)工程文件结构-CSDN博客 开源 C…

【密码学实战】X86、ARM、RISC-V 全量指令集与密码加速技术全景解析

前言 CPU 指令集是硬件与软件交互的核心桥梁,其设计直接决定计算系统的性能边界与应用场景。在数字化时代,信息安全依赖密码算法的高效实现,而指令集扩展则成为密码加速的 “隐形引擎”—— 从服务器端的高吞吐量加密,到移动端的…

2025-08-21 Python进阶2——数据结构

文章目录1 列表(List)1.1 列表常用方法1.2 列表的特殊用途1.2.1 实现堆栈(后进先出)1.2.2 实现队列(先进先出)1.3 列表推导式1.4 嵌套列表推导式2 del 语句3 元组(Tuple)4 集合&…

告别手工编写测试脚本!Claude+Playwright MCP快速生成自动化测试脚本

在进行自动化测试时,前端页面因为频繁迭代UI 结构常有变动,这往往使得自动化测试的脚本往往“写得快、废得也快”,维护成本极高。在大模型之前大家往往都会使用录制类工具,但录制类工具生成的代码灵活性较差、定位方式不太合理只能…

一款更适合 SpringBoot 的API文档新选择(Spring Boot 应用 API 文档)

SpringDoc:Spring Boot 应用 API 文档生成的现代化解决方案 概述 SpringDoc 是一个专为 Spring Boot 应用设计的开源库,能够自动生成符合 OpenAPI 3 规范的 API 文档。它通过扫描项目中的控制器、方法注解及相关配置,动态生成 JSON/YAML/HTML…

文献阅读 250821-When and where soil dryness matters to ecosystem photosynthesis

When and where soil dryness matters to ecosystem photosynthesis 来自 <When and where soil dryness matters to ecosystem photosynthesis | Nature Plants> ## Abstract: Background: Projected increases in the intensity and frequency of droughts in the twen…

React学习(九)

目录&#xff1a;1.react-进阶-antd-新增2.react-进阶-antd-删除选中1.react-进阶-antd-新增新增代码&#xff0c;跟需改的代码类似&#xff0c;直接copy修改组件代码进行修改userEffect可以先带着&#xff0c;没啥用A6组件用到的函数跟修改的也类似&#xff1a;这个useEffect函…

零基础从头教学Linux(Day 17)

三层交换机一、三层交换机的配置1.关于如何配置三层交换机&#xff0c;首先我们应该先创建VLANSwitch>en Switch#vlan database % Warning: It is recommended to configure VLAN from config mode,as VLAN database mode is being deprecated. Please consult userdocument…

任务十四 推荐页面接口开发

一、接口准备 在对接qq音乐接口之前,首先要将之前的项目,一定要记得备份一份; 备份完成之后,首先要在vscode终端安装axios,这个是请求后端的工具,和之前的ajax一样,都是请求后端的工具。只不过axios更专业化,跟强大 至于qq音乐接口怎么获取,一般有两个途径,第一个是…

医疗AI与医院数据仓库的智能化升级:异构采集、精准评估与高效交互的融合方向(下)

核心功能创新详解: 统一门户与角色化工作台: 统一入口: 用户通过单一URL登录,系统根据其角色和权限自动呈现专属工作台。 角色化工作台: 临床医生工作台: 首屏展示常用患者查询入口、快速统计(如“我的患者检验异常趋势”)、相关临床文献推荐、待处理任务(如报告审核)…

数据库面试常见问题

数据库 Delete Truncate Drop 区别 答:这三个操作都是针对数据库的表进行操作,都有删除表的功能,其中的区别在于: Delete:只将表中的数据进行删除,不删除定义不释放空间,是dml语句,需要提交事务,如果不想删除可以回滚。delete每次删除一行,并在事务日志中为所删除…

用nohup setsid绕过超时断连,稳定反弹Shell

在We渗透过程中&#xff0c;我们常常会利用目标系统的远程代码执行&#xff08;RCE&#xff09;漏洞进行反弹Shell。然而&#xff0c;由于Web服务器&#xff08;如PHP、Python后端&#xff09;的执行环境通常存在超时限制&#xff08;如max_execution_time或进程管理策略&#…

Java设计模式-模板方法模式

Java设计模式-模板方法模式 模式概述 模板方法模式简介 核心思想&#xff1a;定义一个操作中的算法骨架&#xff08;模板方法&#xff09;&#xff0c;将算法中某些步骤的具体实现延迟到子类中完成。子类可以在不改变算法整体结构的前提下&#xff0c;重定义这些步骤的行为&…

Centos7物理安装 Redis8.2.0

Centos7物理安装 Redis8.2.0一、准备依赖环境首先安装编译 Redis 所需的依赖&#xff1a;# CentOS/RHEL系统 yum install -y gcc gcc-c make wget 二、下载并编译 Redis 8.2.0# 1. 下载Redis 8.2.0源码包 wget https://download.redis.io/releases/redis-8.2.0.tar.gz# 2. 解压…

牛津大学xDeepMind 自然语言处理(3)

条件语言模型无条件语言模型 概率计算&#xff1a;通过链式法则分解为预测下一词概率&#xff08;将语言建模问题简化为建模给定前面词语历史的下一个词的概率&#xff09;基于循环神经网络的无条件语言模型&#xff1a;根据历史词语预测下一个词的概率条件语言模型 定义&#…