C++ 部署LSTM(.onnx)

0、 背景

在工业自动化控制领域,预测某些变量是否关键。根据工厂的数据,训练好模型之后,将其转我通用的onnx 模型,并实现高效的推理。

模型训练

import numpy as np
from para import *
from data_utils import MyDataset
from data_utils import MyLoss
from torch import nn
import torch
import torch.optim as optim
from  torch.utils.data import DataLoader
from lstm_src_get_data import load_data
device = 'cpu'
num_epochs = g_num_epochs
mod_dir = './'
delay =g_delayclass RegLSTM(nn.Module):def __init__(self, inp_dim, out_dim, mid_dim, mid_layers):super(RegLSTM, self).__init__()self.rnn = nn.LSTM(inp_dim, mid_dim, mid_layers)  # rnnself.reg = nn.Sequential(nn.Linear(mid_dim, mid_dim),nn.Tanh(),nn.Linear(mid_dim, out_dim),)  # regressiondef forward(self, x):y = self.rnn(x)[0]  # y, (h, c) = self.rnn(x)seq_len, batch_size, hid_dim = y.shapey = y.view(-1, hid_dim)y = self.reg(y)y = y.view(seq_len, batch_size, -1)return y
# print(g_delay)model = RegLSTM(g_input_dim, g_output_dim, 100, 1).to(device)if g_is_load_model:model.load_state_dict(torch.load(g_Model_PATH))criterion =MyLoss() # 均方误差损失,用于回归问题
optimizer = optim.Adam(model.parameters(), lr=g_learning_rate)
data_len =   g_train_size+g_seq_len+g_delay*2
data = load_data(0, data_len, 1)
delay =g_delay
data_y_plt = data[delay:-(g_delay+g_seq_num),-1]
train_xs = None
train_ys = None
for i in range(0,g_seq_len*g_seq_num,1):begin_x = ibegin_y = i + delayend_x   = i + g_seq_lenend_y   = i + g_seq_len+delaydata_x = data[begin_x:end_x, :]  # delaydata_y = data[begin_y:end_y, -1]# print('data_y\n', data_y)train_size = len(data_x)train_x = data_x.reshape(-1, g_seq_len,g_input_dim)train_y = data_y.reshape(-1, g_seq_len,g_output_dim)# train_y = np.squeeze(train_y)if train_xs is None:train_xs = train_xtrain_ys = train_yelse:train_xs = np.concatenate((train_xs, train_x), axis=0)train_ys = np.concatenate((train_ys, train_y), axis=0)
dataset = MyDataset(train_xs, train_ys)
# 把 dataset 放入 DataLoader
BATCH_SIZE = g_BATCH_SIZE
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
for epoch in range(num_epochs):loss = Nonefor batch_idx, (inputs, labels) in enumerate(dataloader):outputs = model(inputs)loss = criterion(outputs, labels.to(device))optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播optimizer.step()  # 更新参数if (epoch + 1) % 2 == 0:print(f'epoch [{epoch + 1}], Loss: {loss.item():.6f}')
torch.save(model.state_dict(), '{}/{}'.format(mod_dir,g_Model_PATH_s))
print("Save in:", '{}/{}'.format(mod_dir,g_Model_PATH_s))

2、模型导出


from torch import nn
import torch
from para import *
from models import RegLSTM
# 一个单词向量长度为10,隐藏层节点数为40,LSTM有1层
model = RegLSTM(g_input_dim, g_output_dim, 100, 1).to(g_device)
model.load_state_dict(torch.load(g_Model_PATH, map_location=g_device,weights_only=True))
# 2个句子组成,每个句子由5个单词,单词向量长度为10
input_data = torch.randn(2, 3, g_input_dim)
# 1-> LSTM层数*方向  2->batch  40-> 隐藏层节点数input_names = ["input"]
output_names = ["output"]
save_onnx_path= "./lstm_2_3.onnx"
torch.onnx.export(model,input_data,save_onnx_path,verbose=True,input_names=input_names,output_names=output_names,opset_version=12)

3 onnx 与 .pt 模型精度比较

模型转换为onnx 之后,可能存在精度损失,我们简单测试比较一下onnx 与 .pt 模型的精度。

3.1 .pt 模型运行结果

测试代码

# .pt 模型运行结果
import os, sys
import torch
import numpy as np
sys.path.append(os.getcwd())
import onnxruntime
from para import *
from PIL import Image
from models import RegLSTM
def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# 推理的图片路径
a = np.arange(48).reshape(2, 3, 8)
print(a)input_data = torch.tensor(a)
input_data=input_data.float()model = RegLSTM(g_input_dim, g_output_dim, 100, 1).to(g_device)
model.load_state_dict(torch.load(g_Model_PATH, map_location=g_device,weights_only=True))outputs = model(input_data)
print(outputs)

.pt 模型输出
在这里插入图片描述

3.2 .onnx模型运行结果

测试代码

# onnx 模型运行结果
import os, sys
import torch
import numpy as np
sys.path.append(os.getcwd())
import onnxruntime
from para import *
from PIL import Imagedef to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# 推理的图片路径
a = np.arange(48).reshape(2, 3, 8)
print(a)input_data = torch.tensor(a)
input_data=input_data.float()
# 模型加载
onnx_model_path = "lstm_5_8.onnx"
resnet_session = onnxruntime.InferenceSession(onnx_model_path)
inputs = {resnet_session.get_inputs()[0].name: to_numpy(input_data)}
outs = resnet_session.run(None, inputs)[0]
print(outs)

onnx 模型输出
在这里插入图片描述

3.3.C++ 版本 onnx模型运行结果

测试代码

#include <iostream>#include <iomanip>
using namespace std;
//#include <cuda_provider_factory.h>
#include <onnxruntime_cxx_api.h>
using namespace std;
using namespace Ort;const int batch_size = 2;
const int input_size = 8;
const int seq_len = 3;
const int output_size = 1;std::vector<float> testOnnxLSTM(std::vector<std::vector<std::vector<float>>>& inputs)
{//设置为VERBOSE,方便控制台输出时看到是使用了cpu还是gpu执行//Ort::Env env(ORT_LOGGING_LEVEL_VERBOSE, "test");Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");Ort::SessionOptions session_options;session_options.SetIntraOpNumThreads(1); // 使用五个线程执行op,提升速度// 第二个参数代表GPU device_id = 0,注释这行就是cpu执行//OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);// const char* model_path = "../lstm.onnx";auto model_path = L"./lstm_2_3.onnx";//std::cout << model_path << std::endl;Ort::Session session(env, model_path, session_options);const char* input_names[] = { "input" };    // 根据上节输入接口名称设置const char* output_names[] = { "output" };  // 根据上节输出接口名称设置std::array<float, batch_size* seq_len* input_size> input_matrix;std::array<float, batch_size* seq_len* output_size> output_matrix;std::array<int64_t, 3> input_shape{ batch_size, seq_len, input_size };std::array<int64_t, 3> output_shape{ batch_size,seq_len, output_size };for (int i = 0; i < batch_size; i++)for (int j = 0; j < seq_len; j++)for (int k = 0; k < input_size; k++)input_matrix[i * seq_len * input_size + j * input_size + k] = inputs[i][j][k];Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());try{Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);}catch (const std::exception& e){std::cout << e.what() << std::endl;}std::cout << "get result from LSTM onnx: \n";std::vector<float> ret;for (int i = 0; i < batch_size * seq_len * output_size; i++) {ret.emplace_back(output_matrix[i]);cout << setiosflags(ios::fixed) << setprecision(7) << output_matrix[i] << endl;std::cout << "\n";}cout << setiosflags(ios::fixed) << setprecision(7) << ret.back()<< endl;std::cout << "\n";return ret;
}int main()
{std::vector<std::vector<std::vector<float>>> data;int value = 0;for (int i = 0; i < batch_size; i++) {std::vector<std::vector<float>> t1;for (int j = 0; j < seq_len; j++) {std::vector<float> t2;for (int k = 0; k < input_size; k++) {t2.push_back(value++);}t1.push_back(t2);t2.clear();}data.push_back(t1);t1.clear();}std::cout << "data shape{batch ,seq dim}";std::cout << data.size() << " " << data[0].size() << " " << data[0][0].size() << std::endl;std::cout << "data" << std::endl;for (auto& i : data) {for (auto& j : i) {for (auto& k : j) {std::cout << k << "\t";}std::cout << "\n";}std::cout << "\n";}auto ret = testOnnxLSTM(data);return 0;
}

在这里插入图片描述

4、结果比较

输入
在这里插入图片描述
输出
在这里插入图片描述

可以看出 误差约为百万分之一

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/94841.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/94841.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习-卷积神经网络CNN-1×1卷积层

1x1卷积核&#xff0c;又称为网中网&#xff08;Network in Network&#xff09;&#xff1a;NIN卷积的本质是有效提取相邻像素间的相关特征&#xff0c;而11的卷积显然没有此作用。它失去了卷积层的特有能力——在高度和宽度维度上&#xff0c;识别相邻元素间相互作用的能力。…

使用 Python 异步爬虫抓取豆瓣电影Top250排行榜

导读 在现代网络爬虫开发中,面对 海量网页数据、慢速响应的网络接口,传统的同步爬虫方式已经不能满足高效抓取需求。本文将手把手带你构建一个 基于 aiohttp + asyncio 的异步爬虫系统,实战目标是抓取豆瓣电影 Top250 排行榜中的电影名称、评分和详情页地址。 目录 导读 …

云原生开发全面解析:从核心理念到实践挑战与解决方案

1. 云原生开发的核心理念与定义 云原生&#xff08;Cloud Native&#xff09;是一种基于云计算环境设计和运行应用程序的方法论&#xff0c;其三大技术支柱为容器化、微服务和声明式API。根据CNCF定义&#xff0c;云原生技术通过标准化接口和自动化管理&#xff0c;使应用具备…

WebForms 实例

WebForms 实例 引言 WebForms 是 ASP.NET 技术中的一种重要技术,它允许开发者以表单的形式创建动态网页。本文将通过具体的实例,深入探讨 WebForms 的基本概念、实现方法以及在实际项目中的应用。 WebForms 简介 WebForms 是一种用于创建动态网页的框架,它允许开发者以类…

Java 之 多态

一、多态 多态故名思义&#xff0c;多种状态。比如Animal 这个类中&#xff0c;eat 方法是 公共的方法&#xff0c;但是当 People&#xff0c;Dog , Cat,继承时&#xff0c;我们知道人要吃的是米饭&#xff0c;狗要吃的是狗粮&#xff0c;猫要吃的是猫粮。所以当不同类型的引用…

文件结构树的├、└、─ 符号

目录一、├、└、─符号的背景二、├、└、─ 符号的含义2.1 ├ 带竖线的分支符号2.2 └不带竖线的分支符号2.3 ─横线符号三、Windows系统中生成目录树一、├、└、─符号的背景 我们在编程中&#xff0c;可能会经常遇到一些特殊符号├、└、─。这并非偶然&#xff0c;二十由…

微软XBOX游戏部门大裁员

近日有报道称&#xff0c;微软正计划对Xbox游戏部门进行另外一次裁员&#xff0c;影响的将是整个团队&#xff0c;而不是特定岗位或者部门&#xff0c;大概10%至20%的Xbox团队成员受到影响&#xff0c;这是微软这次对Xbox业务重组的一部分。 据报道&#xff0c;微软已经开始新…

【关于Java 8 的新特性】

问&#xff1a;“Java 8 有啥新东西&#xff1f;” 你憋了半天&#xff0c;只说出一句&#xff1a;“嗯……有 Lambda 表达式。”别慌&#xff01;Java 8 可不只是“语法糖”那么简单。它是一次真正让 Java 从“老派”走向“现代” 的大升级&#xff01;一、Lambda 表达式&…

《嵌入式数据结构笔记(六):二叉树》

1. ​​树数据结构的基本定义和属性​​树是一种重要的非线性数据结构&#xff0c;用于表示层次关系。​​基本定义​​&#xff1a;树是由 n&#xff08;n ≥ 0&#xff09;个结点组成的有限集合。当 n 0 时&#xff0c;称为空树&#xff1b;当 n > 0 时&#xff0c;树必须…

sqlite的sql语法与技术架构研究

(Owed by: 春夜喜雨 http://blog.csdn.net/chunyexiyu) 参考&#xff1a;参考提示词与豆包AI交互输出内容。 sqlite作为最常用的本地数据库&#xff0c;其支持的sql语法也比较全面&#xff0c;历经了二十多年经久不衰&#xff0c;其技术架构设计也是非常优秀的。 一&#xff1a…

Javascript中的一些常见设计模式

1. 单例模式&#xff08;Singleton Pattern&#xff09; 核心思想 一个类只能有一个实例&#xff0c;并提供一个全局访问点。 场景 全局缓存Vuex / Redux 中的 store浏览器中的 localStorage 管理类 示例 const Singleton (function () {let instance;function createInstance…

2025 年最佳 AI 代理:工具、框架和平台比较

目录 什么是 AI Agents 应用 最佳 AI Agents&#xff1a;综合列表 LangGraph AutoGen CrewAI OpenAI Agents SDK Google Agent Development Kit (ADK) 最佳no-code和open-source AI Agents Dify AutoGPT n8n Rasa BotPress 最佳预构建企业 AI agents Devin AI …

Linux 学习 ------Linux 入门(上)

Linux 是一种自由和开放源代码的类 Unix 操作系统。它诞生于 1991 年&#xff0c;由芬兰程序员林纳斯・托瓦兹&#xff08;Linus Torvalds&#xff09;发起并开发。与 Windows 等闭源操作系统不同&#xff0c;Linux 的源代码是公开的&#xff0c;任何人都可以查看、修改和传播&…

[202403-E]春日

[202403-E]春日 题目背景 春水初至&#xff0c; 文笔亦似花开。 题目描述 坐看万紫千红&#xff0c; 提笔洋洋洒洒&#xff0c; 便成篇文章。 现在给你这篇文章&#xff0c; 这篇文章由若干个单词组成&#xff0c; 没有标点符号&#xff0c; 两两单词之间由一个空格隔开。 为了…

Unity笔记(三)——父子关系、坐标转换、Input、屏幕

写在前面写本系列的目的(自用)是回顾已经学过的知识、记录新学习的知识或是记录心得理解&#xff0c;方便自己以后快速复习&#xff0c;减少遗忘。这里只有部分语法知识。九、父子关系1、获取、设置父对象(1)获取父对象可以通过this.transform.parent获取当前对象的父对象Trans…

基于Dubbo的高并发服务治理与流量控制实战指南

基于Dubbo的高并发服务治理与流量控制实战指南 在微服务架构的大规模应用场景中&#xff0c;如何保证服务在高并发压力下的稳定与可用&#xff0c;是每位后端开发者必须面对的挑战。本文结合实际生产环境经验&#xff0c;分享基于Apache Dubbo的高并发服务治理与流量控制方案&a…

Mac 洪泛攻击笔记总结补充

一、Mac 洪泛攻击原理交换机依靠 MAC 地址表来实现数据帧的精准转发&#xff0c;该表记录着端口与相连主机 MAC 地址的对应关系。交换机具备自动学习机制&#xff0c;当收到一个数据帧时&#xff0c;会将帧中的源 MAC 地址与进入的端口号记录到 MAC 表中。同时&#xff0c;由于…

路由器不能上网的解决过程

情况 前段时间&#xff0c;公司来人弄了一下网络后&#xff0c;我的路由器就不能上网了&#xff0c;怎么回事啊。 先看看路由器的情况&#xff1a;看着网络是有连接的&#xff1a;看这上面是能上网的&#xff0c;但是网都是上不去。 奇怪&#xff01; 路由器介绍 路由器&#x…

Rancher 和 KubeSphere对比

以下是 Rancher 与 KubeSphere 的深度对比&#xff0c;涵盖核心定位、架构设计、功能模块、适用场景等关键维度&#xff0c;助您精准选型&#xff1a;一、核心定位与设计哲学维度RancherKubeSphere本质Kubernetes 多集群管理控制平面Kubernetes 全栈云原生操作系统目标简化K8s集…

【深度学习新浪潮】TripoAI是一款什么样的产品?

TripoAI是由硅谷AI初创公司VAST开发的多模态3D内容生成平台,其核心技术基于数十亿参数的3D基础模型,专注于通过文本描述、单图/多图输入或手绘涂鸦快速生成高精度可编辑的3D模型。以下是其核心信息: 一、技术架构与核心功能 秒级生成与多模态输入 生成速度:仅需8秒即可生成…