深度学习模型在C++平台的部署

一、概述

  深度学习模型能够在各种生产场景中发挥重要的作用,而深度学习模型往往在Python环境下完成训练,因而训练好的模型如何在生产环境下实现稳定可靠的部署,便是一个重要内容。C++开发平台广泛存在于各种复杂的生产环境,随着业务效能需求的不断提高,充分运用深度学习技术的优势显得尤为重要。本文介绍如何实现将深度学习模型部署在C++平台上。

二、步骤

  s1. Python环境中安装深度学习框架(如PyTorch、TensorFlow等);

  s2. P ython环境中设计并训练深度学习模型;

  s3. 将训练好的模型保存为.onnx格式的模型文件;

  s4. C++环境中安装Microsoft.ML.OnnxRuntime程序包;
(Visual Studio 2022中可通过项目->管理NuGet程序包完成快捷安装)

  s5. C++环境中加载模型文件,完成功能开发。

三、示例

  在Python环境下设计并训练一个关于手写数字识别的卷积神经网络(CNN)模型,将模型导出为ONNX格式的文件,然后在C++环境下完成对模型的部署和推理。

1. Python训练和导出

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.functional import F# 定义简单的CNN模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.fc1 = nn.Linear(32 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 32 * 7 * 7)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 加载训练数据
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')# 训练模型
train(model, train_loader, criterion, optimizer)# 导出为ONNX格式
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model,dummy_input,"mnist_model.onnx",export_params=True,opset_version=11,do_constant_folding=True,input_names=['input'],output_names=['output'],dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)print("模型已成功导出为mnist_model.onnx")

在这里插入图片描述

2. C++ 部署和推理

#include <iostream>
#include <vector>
#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>int main() {// 初始化环境Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "MNIST");Ort::SessionOptions session_options;session_options.SetIntraOpNumThreads(1);session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);// 加载模型std::wstring model_path = L"mnist_model.onnx";Ort::Session session(env, model_path.c_str(), session_options);// 准备输入std::vector<int64_t> input_shape = { 1, 1, 28, 28 };size_t input_tensor_size = 28 * 28;std::vector<float> input_tensor_values(input_tensor_size);// 读取测试图片cv::Mat test_image = cv::imread("test.jpg", cv::IMREAD_GRAYSCALE);         // 将Mat数据复制到vector中for (int i = 0; i < test_image.rows; ++i) {for (int j = 0; j < test_image.cols; ++j) {input_tensor_values[i * test_image.cols + j] = static_cast<float>(test_image.at<uchar>(i, j)); // 注意:uchar是unsigned char的缩写,表示无符号字符,通常用于存储灰度值}}// 创建输入张量auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, input_shape.data(), 4);// 设置输入输出名称std::vector<const char*> input_names;std::vector<const char*> output_names;input_names.push_back(session.GetInputNameAllocated(0, Ort::AllocatorWithDefaultOptions()).get());output_names.push_back(session.GetOutputNameAllocated(0, Ort::AllocatorWithDefaultOptions()).get());// 运行推理auto output_tensors = session.Run(Ort::RunOptions{ nullptr },input_names.data(),&input_tensor,1,output_names.data(),1);// 获取输出结果float* output = output_tensors[0].GetTensorMutableData<float>();std::vector<float> results(output, output + 10);// 找到预测的数字int predicted_digit = 0;float max_probability = results[0];for (int i = 1; i < 10; i++) {if (results[i] > max_probability) {max_probability = results[i];predicted_digit = i;}}std::cout << "预测结果: " << predicted_digit << std::endl;std::cout << "置信度分布:" << std::endl;for (int i = 0; i < 10; i++) {std::cout << "数字 " << i << ": " << results[i] << std::endl;}return 0;
}

测试图片:
在这里插入图片描述

程序运行:
在这里插入图片描述



End.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/88249.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/88249.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

若以部署在linux,nginx反向代理,登录404,刷新404问题

history模式在router下面的index.js文件的最下面history: createWebHistory(import.meta.env.VITE_APP_CONTEXT_PATH),这两个配置文件都加上然后nginx里面的配置是这个位置按照实际情况&#xff0c;我的是用docker挂载的&#xff0c;所以在/usr/share/nginx/html/lw-clothing为…

SQL Server通过存储过程实现HTML页面生成

引言在现代企业应用中&#xff0c;数据可视化是提升决策效率的关键。SQL Server作为核心数据库管理系统&#xff0c;不仅处理数据存储和查询&#xff0c;还具备强大的扩展能力。通过存储过程直接生成HTML页面&#xff0c;企业能减少对中间层&#xff08;如Web服务器或应用程序&…

qt绘制饼状图并实现点击即放大点击部分

做得比较low #ifndef TEST_POWER_H #define TEST_POWER_H#include <QWidget> #include <QtMath> #include <QPainter> #include <QPushButton> #include <QVector> #include <cmath>namespace Ui { class test_power; } struct PieData {Q…

HashMap的put、get方法详解(附源码)

put方法 HashMap 只提供了 put 用于添加元素&#xff0c;putVal 方法只是给 put 方法调用的一个方法&#xff0c;并没有提供给用户使用。 对 putVal 方法添加元素的分析如下&#xff1a;如果定位到的数组位置没有元素 就直接插入。如果定位到的数组位置有元素就和要插入的 key …

双立柱式带锯床cad【1张总图】+设计说明书+绛重

双立柱式带锯床 摘 要 随着机械制造技术的进步&#xff0c;制造业对于切割设备的精度、效率和稳定性要求越来越高。双立柱式带锯床作为一种重要的切割设备&#xff0c;必须能够满足工业生产对于高精度、高效率的需求。 双立柱式带锯床是一种重要的工业切割设备&#xff0c;其结…

在线JS解密加密配合ECC保护

在线JS解密加密配合ECC保护 1. ECC加密简介 定义 ECC&#xff08;Elliptic Curve Cryptography&#xff09;是一种基于椭圆曲线数学的公钥加密技术&#xff0c;利用椭圆曲线离散对数问题&#xff08;ECDLP&#xff09;实现高安全性。 背景 1985年&#xff1a;Koblitz&#xff0…

使用 Docker Compose 简化 INFINI Console 与 Easysearch 环境搭建

前言回顾 在上一篇文章《搭建持久化的 INFINI Console 与 Easysearch 容器环境》中&#xff0c;我们详细介绍了如何使用基础的 docker run 命令&#xff0c;手动启动和配置 INFINI Console (1.29.6) 和 INFINI Easysearch (1.13.0) 容器&#xff0c;并实现了关键数据的持久化&…

Word 怎么让段落对齐,行与行之间宽一点?

我们来分两步解决&#xff1a;段落对齐 和 调整行距。 这两个功能都集中在Word顶部的【开始】选项卡里的【段落】区域。 第一步&#xff1a;让段落对齐 “对齐”指的是段落的左右边缘如何排列。通常有四种方式。 操作方法&#xff1a;将鼠标光标点在你想修改的那个段落里的任意…

Attention机制完全解析:从原理到ChatGPT实战

一、Attention的本质与计算步骤 1.1 核心思想 动态聚焦&#xff1a;Attention是一种信息分配机制&#xff0c;让模型在处理输入时动态关注最重要的部分。类比&#xff1a;像人类阅读时用荧光笔标记关键句子。 1.2 计算三步曲&#xff08;以"吃苹果"为例&#xff09; …

2025年3月青少年电子学会等级考试 中小学生python编程等级考试三级真题答案解析(判断题)

博主推荐 所有考级比赛学习相关资料合集【推荐收藏】1、Python比赛 信息素养大赛Python编程挑战赛 蓝桥杯python选拔赛真题详解

HTML5 新特性详解:从语义化到多媒体的全面升级

很多小伙伴本都好奇&#xff1a;HTML5有什么功能是以前的HTML没有的&#xff1f; 今天就给大家说道说道 HTML5 作为 HTML 语言的新一代标准&#xff0c;带来了诸多革命性的新特性。这些特性不仅简化了前端开发流程&#xff0c;还大幅提升了网页的用户体验和功能性。本文将深入…

mac安装docker

1、下载docker-desktop https://www.docker.com/products/docker-desktop/2、安装&#xff0c;双击安装 3、优化docker配置 默认配置 cat ~/Library/Group\ Containers/group.com.docker/settings-store.json {"AutoStart": false,"DockerAppLaunchPath": …

mapbox进阶,绘制不随地图旋转的矩形,保证矩形长宽沿屏幕xy坐标方位

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️line线图层样式1.4 ☘️circle点图层样…

${project.basedir}延申出来的Maven内置的一些常用属性

如&#xff1a;${project.basedir} 是 Maven 的内置属性&#xff0c;可以被 pom.xml 直接识别。它表示当前项目的根目录&#xff08;即包含 pom.xml 文件的目录&#xff09;。 Maven 内置的一些常用属性&#xff1a; 项目相关&#xff1a; ${project.basedir} <!-- 项…

[特殊字符] Python 批量生成词云:读取词频 Excel + 自定义背景 + Excel to.png 流程解析

本文展示如何用 Python 从之前生成的词频 Excel 文件中读取词频数据&#xff0c;结合 wordcloud 和背景图&#xff0c;批量生成直观美观的词云图。适用于文本分析、内容展示、报告可视化等场景。 &#x1f4c2; 第一步&#xff1a;读取所有 Excel 词频文件 import os from ope…

模拟网络请求的C++类设计与实现

在C开发中&#xff0c;理解和模拟网络请求是学习客户端-服务器通信的重要一步。本文将详细介绍一个模拟HTTP网络请求的C类库设计&#xff0c;帮助开发者在不涉及实际网络编程的情况下&#xff0c;理解网络请求的核心概念和工作流程。 整体架构设计 这个模拟网络请求的类库主要由…

移动机器人的认知进化:Deepoc大模型重构寻迹本质

统光电寻迹技术已逼近物理极限。当TCRT5000传感器在强烈环境光下失效率超过37%&#xff0c;当PID控制器在路径交叉口产生63%的决策崩溃&#xff0c;工业界逐渐意识到&#xff1a;导引线束缚的不仅是车轮&#xff0c;更是机器智能的演化可能性。 ​技术破局点出现在具身认知架构…

记录一次pip安装错误OSError: [WinError 32]的解决过程

因为要使用 PaddleOCR&#xff0c;需要安装依赖。先通过 conda新建了虚拟环境&#xff0c;然后安装 PaddlePaddle&#xff0c;继续安装 PaddleOCR&#xff0c;上述过程我是在 VSCode的终端中处理&#xff0c;结果报错如下&#xff1a;Downloading multidict-6.6.3-cp312-cp312-…

后端id设置long类型时,传到前端,超过19位最后两位为00

文章目录一、前言二、问题描述2.1、问题背景2.2、问题示例三、解决方法3.1、将ID转换为字符串3.2、使用JsonSerialize注解3.3、使用JsonFormat注解一、前言 在后端开发中&#xff0c;我们经常会遇到需要将ID作为标识符传递给前端的情况。当ID为long类型时&#xff0c;如果该ID…

SpringAI学习笔记-MCP客户端简单示例

MCP客户端是AI与外部世界交互的桥梁。在AI系统中&#xff0c;大模型虽然具备强大的认知能力&#xff0c;却常常受限于数据孤岛问题&#xff0c;无法直接访问外部工具和数据源。MCP协议应运而生&#xff0c;作为标准化接口解决这一核心挑战。该协议采用客户端-服务端架构&#x…