从 PyTorch 到 TensorFlow Lite:模型训练与推理

一、方案介绍

  1. 研发阶段:利用 PyTorch 的动态图特性进行快速原型验证,快速迭代模型设计。
    • 灵活性与易用性:PyTorch 是一个非常灵活且易于使用的深度学习框架,特别适合研究和实验。其动态计算图特性使得模型的构建和调试变得更加直观,开发者可以在运行时修改模型结构。
    • 快速原型开发:许多研究人员和开发者选择 PyTorch 进行模型训练,因为它支持快速原型开发和灵活的模型设计,能够快速验证新想法并进行迭代。
  2. 转换阶段:将训练好的模型通过 TorchScript 导出为 ONNX 格式,再转换为 TensorFlow 格式,最后生成 TFLite 模型。
    • 专为移动和嵌入式设备优化:TensorFlow Lite 是专为移动和嵌入式设备设计的推理框架,能够在资源有限的环境中高效运行模型,确保在各种设备上实现实时推理。
    • 支持模型量化和优化:TFLite 支持模型量化和优化,能够显著减小模型大小并提高推理速度,适合在手机、边缘设备等场景中使用。这使得开发者能够在不牺牲准确度的情况下,提升模型的运行效率。
  3. 部署阶段:将 TFLite 模型集成到 Android、iOS 或嵌入式系统中,确保模型能够在目标设备上高效运行。
    • 内存和计算资源的优化:在推理阶段,使用 TFLite 可以减少内存占用和计算资源消耗,尤其是在移动设备和嵌入式系统上。这对于需要长时间运行的应用尤为重要,可以延长设备的电池寿命。
    • 多种优化技术:TFLite 提供了多种优化技术,如模型量化(将浮点数转换为整数),可以进一步提高推理速度并降低功耗。这使得在实时应用中能够实现更快的响应时间,提升用户体验。
      在这里插入图片描述

二、实例1:CNN模型的转换

注:python 版本为3.10

2.1 pytorch模型训练

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = CNNModel().to(device)  # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
for epoch in range(20):for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/20], Loss: {loss.item():.6f}')# 保存模型
torch.save(model.state_dict(), 'cnn_mnist.pth')
print("Model saved as cnn_mnist.pth")

2.2 pth模型转onnx 并验证一致性

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn# 定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 加载模型并进行推理
model = CNNModel()
model.load_state_dict(torch.load('cnn_mnist.pth', weights_only=True))  # 加载保存的模型权重
model.eval()  # 设置为评估模式# 创建一个示例输入
dummy_input = torch.randn(1, 1, 28, 28)  # MNIST 图像的形状# 使用 PyTorch 进行推理
with torch.no_grad():pytorch_output = model(dummy_input)# 导出模型为 ONNX 格式
torch.onnx.export(model, dummy_input, 'cnn_mnist.onnx', export_params=True, opset_version=11)
print("Model exported to cnn_mnist.onnx")# 使用 ONNX 进行推理
onnx_model = onnx.load('cnn_mnist.onnx')
ort_session = ort.InferenceSession('cnn_mnist.onnx')# 准备输入数据
onnx_input = dummy_input.numpy()  # 将 PyTorch 张量转换为 NumPy 数组
onnx_input = onnx_input.astype(np.float32)  # 确保数据类型为 float32# 使用 ONNX 进行推理
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: onnx_input})# 比较输出
pytorch_output_np = pytorch_output.numpy()  # 将 PyTorch 输出转换为 NumPy 数组
onnx_output_np = onnx_output[0]  # ONNX 输出是一个列表,取第一个元素# 检查输出是否一致
if np.allclose(pytorch_output_np, onnx_output_np, atol=1e-5):print("The outputs are consistent between PyTorch and ONNX.")
else:print("The outputs are NOT consistent between PyTorch and ONNX.")# 打印输出结果
print("PyTorch output:", pytorch_output_np)
print("ONNX output:", onnx_output_np)
The outputs are consistent between PyTorch and ONNX.
PyTorch output: [[ -1.5153266 -11.934659    0.5428004 -16.058285   -3.6684208  -4.596178-14.53585    -3.3159208  -5.7872214  -5.3301578]]
ONNX output: [[ -1.5153263 -11.934658    0.5428015 -16.058285   -3.66842    -4.5961757-14.53585    -3.3159204  -5.787223   -5.3301597]]

2.3 onnx模型转tflite

参考这个项目:onnx2tflite

git clone https://github.com/MPolaris/onnx2tflite.git
cd onnx2tflite
conda install tensorflow=2.11.0
pip install .
python -m onnx2tflite --weights ../pth2onnx/cnn_mnist.onnx

在这里插入图片描述

2.4 onnx模型和tflite一致性验证

import numpy as np
import onnxruntime as ort
import tensorflow as tf# 1. 加载 ONNX 模型
onnx_model_path = 'cnn_mnist.onnx'
onnx_session = ort.InferenceSession(onnx_model_path)# 2. 加载 TFLite 模型
tflite_model_path = 'cnn_mnist.tflite'
tflite_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_interpreter.allocate_tensors()# 3. 准备输入数据
# 假设输入数据是 MNIST 数据集的一部分,形状为 (1, 28, 28, 1)
input_data = np.random.rand(1, 28, 28, 1).astype(np.float32)  # Keras 输入
input_data_onnx = input_data.transpose(0, 3, 1, 2)  # 转换为 ONNX 输入格式 (1, 1, 28, 28)# 4. 使用相同的输入数据进行推理# ONNX 模型推理
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: input_data_onnx})[0]
print("ONNX Output:", onnx_output)# TFLite 模型推理
tflite_input_details = tflite_interpreter.get_input_details()
tflite_output_details = tflite_interpreter.get_output_details()# 检查 TFLite 输入形状
print("TFLite Input Shape:", tflite_input_details[0]['shape'])# 设置 TFLite 输入
# 确保输入数据的形状与 TFLite 模型的输入要求一致
tflite_interpreter.set_tensor(tflite_input_details[0]['index'], input_data)
tflite_interpreter.invoke()
tflite_output = tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
print("TFLite Output:", tflite_output)# 5. 比较输出结果
# 计算输出的差异
onnx_difference = np.abs(onnx_output - tflite_output)# 输出结果
print("Difference (ONNX vs TFLite):", onnx_difference)# 检查是否一致
if np.all(onnx_difference < 1e-5):  # 设定一个阈值print("The outputs are consistent between ONNX and TFLite models.")
else:print("The outputs are not consistent between ONNX and TFLite models.")
ONNX Output: [[ -3.7372704  -6.5073314  -1.1807165  -2.4232314 -10.638929    2.2660115-4.5868526  -2.7494073  -0.5609715  -6.331989 ]]
TFLite Input Shape: [ 1 28 28  1]
TFLite Output: [[ -3.7372704   -6.5073323   -1.180716    -2.4232314  -10.6389282.2660117   -4.5868545   -2.7494078   -0.56097114  -6.331988  ]]
Difference (ONNX vs TFLite): [[0.0000000e+00 9.5367432e-07 4.7683716e-07 0.0000000e+00 9.5367432e-072.3841858e-07 1.9073486e-06 4.7683716e-07 3.5762787e-07 9.5367432e-07]]
The outputs are consistent between ONNX and TFLite models.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/85238.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4.2.5 Spark SQL 分区自动推断

在本节实战中&#xff0c;我们学习了Spark SQL的分区自动推断功能&#xff0c;这是一种提升查询性能的有效手段。通过创建具有不同分区的目录结构&#xff0c;并在这些目录中放置JSON文件&#xff0c;我们模拟了一个分区表的环境。使用Spark SQL读取这些数据时&#xff0c;Spar…

数据结构:导论

目录 什么是“第一性原理”&#xff1f; 什么是“数据结构”&#xff1f; 数据结构解决的根本问题是什么&#xff1f; 数据结构的两大分类 数据结构的基本操作 数据结构与算法的关系 学习数据结构的底层目标 什么是“第一性原理”&#xff1f; 在正式进入数据结构之前&…

汽车制造场景下Profibus转Profinet网关核心功能与应用解析

在当今工业自动化的浪潮中&#xff0c;各种通讯协议层出不穷&#xff0c;而其中PROFIBUS与PROFINET作为两种主流的工业通信标准&#xff0c;它们之间的转换需求日益增长。特别是对于那些希望实现老旧设备与现代化网络无缝对接的企业来说&#xff0c;一个高效、稳定的网关产品显…

qt ubuntu 20.04 交叉编译

一、交叉编译环境搭建 1.下载交叉编译工具链&#xff1a;https://developer.arm.com/downloads/-/gnu-a 可以根据自己需要下载对应版本&#xff0c;当前最新版本是10.3, 笔者使用10.3编译后的glibc.so版本太高&#xff08;glibc_2.3.3, glibc_2.3.4, glibc_2.3.5&#xff09;…

在Babylon.js中创建3D文字:简单而强大的方法

引言 在3D场景中添加文字是许多WebGL项目的常见需求。Babylon.js提供了多种创建3D文字的方法&#xff0c;其中使用TextBlock结合平面网格是一种简单而高效的方式。本文将介绍如何使用Babylon.js的GUI系统在3D空间中创建美观的文字效果。 方法概述 Babylon.js的GUI系统允许我…

油桃TV v20250519 一款电视端应用网站聚合TV播放器 支持安卓4.1

油桃TV v20250519 一款电视端应用网站聚合TV播放器 支持安卓4.1 应用简介&#xff1a; 油桃TV是一款开源电视端应用网站聚合浏览器&#xff0c;它把大家常见需求的一些网站都整合到了这个应用上&#xff0c;并进行了电视端…

Perl单元测试实战指南:从Test::Class入门到精通的完整方案

阅读原文 前言:为什么Perl开发者需要重视单元测试? "这段代码昨天还能运行,今天就出问题了!"——这可能是每位Perl开发者都经历过的噩梦。在没有充分测试覆盖的情况下,即使是微小的改动也可能导致系统崩溃。单元测试正是解决这一痛点的最佳实践,它能帮助我们在…

OpenCv高阶(十三)——人脸检测

文章目录 前言一、人脸检测—haar特征二、人脸检测---级联分类器1、级联分类器2、如何训练级联分类器3、已存在的级联分类器 三、代码分析1、人脸检测的简单使用2、人脸微笑检测&#xff08;1&#xff09; 初始化视频源&#xff08;2&#xff09;主循环处理每一帧&#xff08;3…

无线通信模块简介

QuecPython 是运行在无线通信模块上的开发框架。对于首次接触物联网开发的用户而言&#xff0c;无线通信模块可能是一个相对陌生的概念。本文主要针对无线通信和蜂窝网络本身&#xff0c;以及模块的概念、特性和开发方式进行简要的介绍。 无线通信和蜂窝网络 物联网对无线通信…

Unity 中实现首尾无限循环的 ListView

之前已经实现过&#xff1a; Unity 中实现可复用的 ListView-CSDN博客文章浏览阅读5.6k次&#xff0c;点赞2次&#xff0c;收藏27次。源码已放入我的 github&#xff0c;地址&#xff1a;Unity-ListView前言实现一个列表组件&#xff0c;表现方面最核心的部分就是重写布局&…

【C++】 类和对象(上)

1.类的定义 1.1类的定义格式 • class为定义类的关键字&#xff0c;后跟一个类的名字&#xff0c;{}中为类的主体&#xff0c;注意类定义结束时后⾯分号不能省 略。类体中内容称为类的成员&#xff1a;类中的变量称为类的属性或成员变量;类中的函数称为类的⽅法或 者成员函数。…

Transformer架构详解:从Attention到ChatGPT

Transformer架构详解&#xff1a;从Attention到ChatGPT 系统化学习人工智能网站&#xff08;收藏&#xff09;&#xff1a;https://www.captainbed.cn/flu 文章目录 Transformer架构详解&#xff1a;从Attention到ChatGPT摘要引言一、Attention机制&#xff1a;Transformer的…

Rock9.x(Linux)安装Redis7

&#x1f49a;提醒&#xff1a;1&#xff09;注意权限问题 &#x1f49a; 查是否已经安装了gcc gcc 是C语言编译器&#xff0c;Redis是用C语言开发的&#xff0c;我们需要编译它。 gcc --version如果没有安装gcc&#xff0c;那么我们手动安装 安装GCC sudo dnf -y install…

EasyExcel使用导出模版后设置 CellStyle失效问题解决

EasyExcel使用导出模版后在CellWriteHandler的afterCellDispose方法设置 CellStyle失效问题解决方法 问题描述&#xff1a;excel 模版塞入数据后&#xff0c;需要设置单元格的个性化设置时失效&#xff0c;本文以设置数据格式为例&#xff08;设置列的数据展示时需要加上千分位…

【Day41】

DAY 41 简单CNN 知识回顾 数据增强卷积神经网络定义的写法batch归一化&#xff1a;调整一个批次的分布&#xff0c;常用与图像数据特征图&#xff1a;只有卷积操作输出的才叫特征图调度器&#xff1a;直接修改基础学习率 卷积操作常见流程如下&#xff1a; 1. 输入 → 卷积层 →…

Express教程【002】:Express监听GET和POST请求

文章目录 2、监听post和get请求2.1 监听GET请求2.2 监听POST请求 2、监听post和get请求 创建02-app.js文件。 2.1 监听GET请求 1️⃣通过app.get()方法&#xff0c;可以监听客户端的GET请求&#xff0c;具体的语法格式如下&#xff1a; // 1、导入express const express req…

C# 文件 I/O 操作详解:从基础到高级应用

在软件开发中&#xff0c;文件操作&#xff08;I/O&#xff09;是一项基本且重要的功能。无论是读取配置文件、存储用户数据&#xff0c;还是处理日志文件&#xff0c;C# 都提供了丰富的 API 来高效地进行文件读写操作。本文将全面介绍 C# 中的文件 I/O 操作&#xff0c;涵盖基…

Vue-Router简版手写实现

1. 路由库工程设计 首先&#xff0c;我们需要创建几个核心文件来组织我们的路由库&#xff1a; src/router/index.tsRouterView.tsRouterLink.tsuseRouter.tsinjectionsymbols.tshistory.ts 2. injectionSymbols.ts 定义一些注入符号来在应用中共享状态&#xff1a; import…

Electron-vite【实战】MD 编辑器 -- 文件列表(含右键快捷菜单,重命名文件,删除本地文件,打开本地目录等)

最终效果 页面 src/renderer/src/App.vue <div class"dirPanel"><div class"panelTitle">文件列表</div><div class"searchFileBox"><Icon class"searchFileInputIcon" icon"material-symbols-light:…

Remote Sensing投稿记录(投稿邮箱写错、申请大修延期...)风雨波折投稿路

历时近一个半月&#xff0c;我中啦&#xff01; RS是中科院二区&#xff0c;2023-2024影响因子4.2&#xff0c;五年影响因子4.9。 投稿前特意查了下预警&#xff0c;发现近五年都不在预警名单中&#xff0c;甚至最新中科院SCI分区&#xff08;2025年3月&#xff09;在各小类上…