onnx入门教程(二)—— PyTorch 转 ONNX 详解

  在这一节里,我们将详细介绍 PyTorch 到 ONNX 的转换函数—— torch.onnx.export。我们希望大家能够更加灵活地使用这个模型转换接口,并通过了解它的实现原理来更好地应对该函数的报错(由于模型部署的兼容性问题,部署复杂模型时该函数时常会报错)。

1.计算图导出方法

  TorchScript 是一种序列化和优化 PyTorch 模型的格式,在优化过程中,一个torch.nn.Module模型会被转换成 TorchScript 的 torch.jit.ScriptModule模型。现在, TorchScript 也被常当成一种中间表示使用。我们在其他文章中对 TorchScript 有详细的介绍(TorchScript 解读(一):初识 TorchScript - 知乎),这里介绍 TorchScript 仅用于说明 PyTorch 模型转 ONNX的原理。
torch.onnx.export中需要的模型实际上是一个torch.jit.ScriptModule。而要把普通 PyTorch 模型转一个这样的 TorchScript 模型,有跟踪(trace)和记录(script)两种导出计算图的方法。如果给torch.onnx.export传入了一个普通 PyTorch 模型(torch.nn.Module),那么这个模型会默认使用跟踪的方法导出。这一过程如下图所示:
在这里插入图片描述
回忆一下我们第一篇教程知识:跟踪法只能通过实际运行一遍模型的方法导出模型的静态图,即无法识别出模型中的控制流(如循环);记录法则能通过解析模型来正确记录所有的控制流。我们以下面这段代码为例来看一看这两种转换方法的区别:

import torch class Model(torch.nn.Module): def __init__(self, n): super().__init__() self.n = n self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): for i in range(self.n): x = self.conv(x) return x models = [Model(2), Model(3)] 
model_names = ['model_2', 'model_3'] for model, model_name in zip(models, model_names): dummy_input = torch.rand(1, 3, 10, 10) dummy_output = model(dummy_input) model_trace = torch.jit.trace(model, dummy_input) model_script = torch.jit.script(model) # 跟踪法与直接 torch.onnx.export(model, ...)等价 torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output) # 记录法必须先调用 torch.jit.sciprt torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output) 

  在这段代码里,我们定义了一个带循环的模型,模型通过参数n来控制输入张量被卷积的次数。之后,我们各创建了一个n=2和n=3的模型。我们把这两个模型分别用跟踪和记录的方法进行导出。
值得一提的是,由于这里的两个模型(model_trace, model_script)是 TorchScript 模型,export函数已经不需要再运行一遍模型了。(如果模型是用跟踪法得到的,那么在执行torch.jit.trace的时候就运行过一遍了;而用记录法导出时,模型不需要实际运行)参数中的dummy_input和dummy_output`仅仅是为了获取输入和输出张量的类型和形状。
运行上面的代码,我们把得到的 4 个 onnx 文件用 Netron 可视化:
在这里插入图片描述
首先看跟踪法得到的 ONNX 模型结构。可以看出来,对于不同的 n,ONNX 模型的结构是不一样的。
在这里插入图片描述
而用记录法的话,最终的 ONNX 模型用 Loop 节点来表示循环。这样哪怕对于不同的 n,ONNX 模型也有同样的结构。
由于推理引擎对静态图的支持更好,通常我们在模型部署时不需要显式地把 PyTorch 模型转成 TorchScript 模型,直接把 PyTorch 模型用 torch.onnx.export 跟踪导出即可。了解这部分的知识主要是为了在模型转换报错时能够更好地定位问题是否发生在 PyTorch 转 TorchScript 阶段。

2.参数讲解

  了解完转换函数的原理后,我们来详细介绍一下该函数的主要参数的作用。我们主要会从应用的角度来介绍每个参数在不同的模型部署场景中应该如何设置,而不会去列出每个参数的所有设置方法。该函数详细的 API 文档可参考:

torch.onnx.export 在 torch.onnx.init.py文件中的定义如下:

def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=True, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True, use_external_data_format=False): 

  前三个必选参数为模型、模型输入、导出的 onnx 文件名,我们对这几个参数已经很熟悉了。我们来着重看一下后面的一些常用可选参数。

  • export_params
    模型中是否存储模型权重。一般中间表示包含两大类信息:模型结构和模型权重,这两类信息可以在同一个文件里存储,也可以分文件存储。ONNX 是用同一个文件表示记录模型的结构和权重的。
    我们部署时一般都默认这个参数为 True。如果 onnx 文件是用来在不同框架间传递模型(比如 PyTorch 到 Tensorflow)而不是用于部署,则可以令这个参数为 False。
  • input_names, output_names
    设置输入和输出张量的名称。如果不设置的话,会自动分配一些简单的名字(如数字)。
    ONNX 模型的每个输入和输出张量都有一个名字。很多推理引擎在运行 ONNX 文件时,都需要以“名称-张量值”的数据对来输入数据,并根据输出张量的名称来获取输出数据。在进行跟张量有关的设置(比如添加动态维度)时,也需要知道张量的名字。
    在实际的部署流水线中,我们都需要设置输入和输出张量的名称,并保证 ONNX 和推理引擎中使用同一套名称。
  • opset_version
    转换时参考哪个 ONNX 算子集版本,默认为 9。后文会详细介绍 PyTorch 与 ONNX 的算子对应关系。
  • dynamic_axes
    指定输入输出张量的哪些维度是动态的。
    为了追求效率,ONNX 默认所有参与运算的张量都是静态的(张量的形状不发生改变)。但在实际应用中,我们又希望模型的输入张量是动态的,尤其是本来就没有形状限制的全卷积模型。因此,我们需要显式地指明输入输出张量的哪几个维度的大小是可变的。
    我们来看一个dynamic_axes的设置例子:
import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv(x) return x model = Model() 
dummy_input = torch.rand(1, 3, 10, 10) 
model_names = ['model_static.onnx',  
'model_dynamic_0.onnx',  
'model_dynamic_23.onnx'] dynamic_axes_0 = { 'in' : [0], 'out' : [0] 
} 
dynamic_axes_23 = { 'in' : [2, 3], 'out' : [2, 3] 
} torch.onnx.export(model, dummy_input, model_names[0],  
input_names=['in'], output_names=['out']) 
torch.onnx.export(model, dummy_input, model_names[1],  
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0) 
torch.onnx.export(model, dummy_input, model_names[2],  
input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_23) 

  首先,我们导出 3 个 ONNX 模型,分别为没有动态维度、第 0 维动态、第 2 第 3 维动态的模型。
在这份代码里,我们是用列表的方式表示动态维度,例如:

dynamic_axes_0 = { 'in' : [0], 'out' : [0] 
} 

  由于在这份代码里我们没有更多的对动态维度的操作,因此简单地用列表指定动态维度即可。
之后,我们用下面的代码来看一看动态维度的作用:

import onnxruntime 
import numpy as np origin_tensor = np.random.rand(1, 3, 10, 10).astype(np.float32) 
mult_batch_tensor = np.random.rand(2, 3, 10, 10).astype(np.float32) 
big_tensor = np.random.rand(1, 3, 20, 20).astype(np.float32) inputs = [origin_tensor, mult_batch_tensor, big_tensor] 
exceptions = dict() for model_name in model_names: for i, input in enumerate(inputs): try: ort_session = onnxruntime.InferenceSession(model_name) ort_inputs = {'in': input} ort_session.run(['out'], ort_inputs) except Exception as e: exceptions[(i, model_name)] = e print(f'Input[{i}] on model {model_name} error.') else: print(f'Input[{i}] on model {model_name} succeed.') 

  我们在模型导出计算图时用的是一个形状为(1, 3, 10, 10)的张量。现在,我们来尝试以形状分别是(1, 3, 10, 10), (2, 3, 10, 10), (1, 3, 20, 20)为输入,用ONNX Runtime运行一下这几个模型,看看哪些情况下会报错,并保存对应的报错信息。得到的输出信息应该如下:

Input[0] on model model_static.onnx succeed. 
Input[1] on model model_static.onnx error. 
Input[2] on model model_static.onnx error. 
Input[0] on model model_dynamic_0.onnx succeed. 
Input[1] on model model_dynamic_0.onnx succeed. 
Input[2] on model model_dynamic_0.onnx error. 
Input[0] on model model_dynamic_23.onnx succeed. 
Input[1] on model model_dynamic_23.onnx error. 
Input[2] on model model_dynamic_23.onnx succeed. 

  可以看出,形状相同的(1, 3, 10, 10)的输入在所有模型上都没有出错。而对于batch(第 0 维)或者长宽(第 2、3维)不同的输入,只有在设置了对应的动态维度后才不会出错。我们可以错误信息中找出是哪些维度出了问题。比如我们可以用以下代码查看input[1]在model_static.onnx中的报错信息:

print(exceptions[(1, 'model_static.onnx')]) # output 
# [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: in for the following indices index: 0 Got: 2 Expected: 1 Please fix either the inputs or the model. 

  这段报错告诉我们名字叫in的输入的第 0 维不匹配。本来该维的长度应该为 1,但我们的输入是 2。实际部署中,如果我们碰到了类似的报错,就可以通过设置动态维度来解决问题。

3.使模型在 ONNX 转换时有不同的行为

  有些时候,我们希望模型在导出至 ONNX 时有一些不同的行为模型在直接用 PyTorch 推理时有一套逻辑,而在导出的ONNX模型中有另一套逻辑。比如,我们可以把一些后处理的逻辑放在模型里,以简化除运行模型之外的其他代码。torch.onnx.is_in_onnx_export()可以实现这一任务,该函数仅在执行 torch.onnx.export()时为真。以下是一个例子:

import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv(x) if torch.onnx.is_in_onnx_export(): x = torch.clip(x, 0, 1) return x 

  这里,我们仅在模型导出时把输出张量的数值限制在[0, 1]之间。使用 is_in_onnx_export确实能让我们方便地在代码中添加和模型部署相关的逻辑。但是,这些代码对只关心模型训练的开发者和用户来说很不友好,突兀的部署逻辑会降低代码整体的可读性。同时,is_in_onnx_export只能在每个需要添加部署逻辑的地方都“打补丁”,难以进行统一的管理。我们之后会介绍如何使用 MMDeploy 的重写机制来规避这些问题。

4.利用中断张量跟踪的操作

  PyTorch 转 ONNX 的跟踪导出法是不是万能的。如果我们在模型中做了一些很“出格”的操作,跟踪法会把某些取决于输入的中间结果变成常量,从而使导出的 ONNX 模型和原来的模型有出入。以下是一个会造成这种“跟踪中断”的例子:

class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): x = x * x[0].item() return x, torch.Tensor([i for i in x]) model = Model()       
dummy_input = torch.rand(10) 
torch.onnx.export(model, dummy_input, 'a.onnx') 

  如果你尝试去导出这个模型,会得到一大堆 warning,告诉你转换出来的模型可能不正确。这也难怪,我们在这个模型里使用了.item()把 torch 中的张量转换成了普通的 Python 变量,还尝试遍历 torch 张量,并用一个列表新建一个 torch 张量。这些涉及张量与普通变量转换的逻辑都会导致最终的 ONNX 模型不太正确。
另一方面,我们也可以利用这个性质,在保证正确性的前提下令模型的中间结果变成常量。这个技巧常常用于模型的静态化上,即令模型中所有的张量形状都变成常量。在未来的教程中,我们会在部署实例中详细介绍这些“高级”操作。

5.PyTorch 对 ONNX 的算子支持

  在确保torch.onnx.export()的调用方法无误后,PyTorch 转 ONNX 时最容易出现的问题就是算子不兼容了。这里我们会介绍如何判断某个 PyTorch 算子在 ONNX 中是否兼容,以助大家在碰到报错时能更好地把错误归类。而具体添加算子的方法我们会在之后的文章里介绍。
在转换普通的torch.nn.Module模型时,PyTorch 一方面会用跟踪法执行前向推理,把遇到的算子整合成计算图;另一方面,PyTorch 还会把遇到的每个算子翻译成 ONNX 中定义的算子。在这个翻译过程中,可能会碰到以下情况:
该算子可以一对一地翻译成一个 ONNX 算子。
该算子在 ONNX 中没有直接对应的算子,会翻译成一至多个 ONNX 算子。
该算子没有定义翻译成 ONNX 的规则,报错。
那么,该如何查看 PyTorch 算子与 ONNX 算子的对应情况呢?由于PyTorch 算子是向 ONNX 对齐的,这里我们先看一下 ONNX 算子的定义情况,再看一下PyTorch 定义的算子映射关系。

6.ONNX 算子文档

  ONNX 算子的定义情况,都可以在官方的算子文档中查看。这份文档十分重要,我们碰到任何和 ONNX 算子有关的问题都得来”请教“这份文档。
在这里插入图片描述
这份文档中最重要的开头的这个算子变更表格。表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export中提到的opset_version表示的算子集版本号。通过查看算子第一次发生变动的版本号,我们可以知道某个算子是从哪个版本开始支持的;通过查看某算子小于等于opset_version的第一个改动记录,我们可以知道当前算子集版本中该算子的定义规则。
在这里插入图片描述
通过点击表格中的链接,我们可以查看某个算子的输入、输出参数规定及使用示例。比如上图是 Relu 在 ONNX 中的定义规则,这份定义表明 Relu 应该有一个输入和一个输入,输入输出的类型相同,均为 tensor。

7.PyTorch 对 ONNX 算子的映射

  在 PyTorch 中,和 ONNX 有关的定义全部放在 torch.onnx目录中,如下图所示:
在这里插入图片描述
其中,symbolic_opset{n}.py(符号表文件)即表示 PyTorch 在支持第 n 版 ONNX 算子集时新加入的内容。我们之前讲过, bicubic 插值是在第 11 个版本开始支持的。我们以它为例来看看如何查找算子的映射情况。
首先,使用搜索功能,在torch/onnx文件夹搜索"bicubic",可以发现这个这个插值在第 11 个版本的定义文件中:
在这里插入图片描述
之后,我们按照代码的调用逻辑,逐步跳转直到最底层的 ONNX 映射函数:

upsample_bicubic2d = _interpolate("upsample_bicubic2d", 4, "cubic") -> def _interpolate(name, dim, interpolate_mode): return sym_help._interpolate_helper(name, dim, interpolate_mode) -> def _interpolate_helper(name, dim, interpolate_mode): def symbolic_fn(g, input, output_size, *args): ... return symbolic_fn 

  最后,在symbolic_fn中,我们可以看到插值算子是怎么样被映射成多个 ONNX 算子的。其中,每一个g.op就是一个 ONNX 的定义。比如其中的 Resize 算子就是这样写的:

return g.op("Resize", input, empty_roi, empty_scales, output_size, coordinate_transformation_mode_s=coordinate_transformation_mode, cubic_coeff_a_f=-0.75,  # only valid when mode="cubic" mode_s=interpolate_mode,  # nearest, linear, or cubic nearest_mode_s="floor")  # only valid when mode="nearest" 

  通过在前面提到的ONNX 算子文档中查找 Resize 算子的定义,我们就可以知道这每一个参数的含义了。用类似的方法,我们可以去查询其他 ONNX 算子的参数含义,进而知道 PyTorch 中的参数是怎样一步一步传入到每个 ONNX 算子中的。
掌握了如何查询 PyTorch 映射到 ONNX 的关系后,我们在实际应用时就可以在 torch.onnx.export()的opset_version中先预设一个版本号,碰到了问题就去对应的 PyTorch 符号表文件里去查。如果某算子确实不存在,或者算子的映射关系不满足我们的要求,我们就可能得用其他的算子绕过去,或者自定义算子了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/94099.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/94099.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式LINUX——————网络TCP

一、TCP连接1.TCP特点:(1)面向链接(2)面向字节流(3)安全可靠的传输协议,因为会先建立连接(4)占用资源开销大,效率低,实时性不佳&#…

alicloud 阿里云有哪些日志 审计日志

1: 阿里有哪些audit log: Audit Related Logs Below table describe the logs available in Log Service that might be applicable to the Security Operations Team. 2: 怎么来分析呢? Overview Its recommended to built a program with SLS Consumer Group which real…

如何理解AP服务发现协议中“如果某项服务需要被配置为可通过多个不同的网络接口进行访问,则应为每个网络接口使用一个独立的客户端服务实例”?

上一句:[PRS_SOMEIPSD_00238]◎ 「如果某项服务需要在多个网络接口上提供,则应为每个网络接口使用一个独立的服务器服务实例。」(RS_SOMEIPSD_00003) 本句:[PRS_SOMEIPSD_00239] 「如果某项服务需要被配置为可通过多个不同的网络接口进行访问…

piecewise jerk算法介绍

piecewise jerk算法介绍 piecewise jerk算法是百度Apollo中的一种用于路径和速度平滑的算法,该算法假设相邻点之间的jerk为常数,基于该假设将平滑问题构建为二次规划问题,调用osqp求解器求解。参考论文为:Optimal Vehicle Path Pl…

分布式蜜罐系统的部署安装

前阵子勒索病毒泛滥,中小企业由于缺少专业EDR,态势感知,IPS等设备,往往是在勒索事件发生之后才后知后觉,也因为缺乏有效的备份策略,导致数据,经济,商业信誉的丧失,甚至还…

定时器互补PWM输出和死区

定时器互补PWM输出和死区互补PWM(Complementary PWM)H桥、全桥、半桥中的应用为什么需要死区时间互补PWM(Complementary PWM) 是一种特殊的 PWM 输出模式,通常用于H桥、全桥或半桥电路的驱动。其核心原理是利用定时器…

嵌入式ARM程序高级调试基础:8.QEMU ARM虚拟机与tftp配置

嵌入式ARM程序高级调试基础:8.QEMU ARM虚拟机与tftp配置 文章目录 嵌入式ARM程序高级调试基础:8.QEMU ARM虚拟机与tftp配置 一.总的网络配置过程 二.主机配置 三.QEMU ARM 网络配置 四.主机与虚拟器之间的网络测试 五.TFTP网络配置 5.1 ubuntu主机安装tftp服务器 5.2 设置tft…

【贪心算法】贪心算法六

贪心算法六 1.坏了的计算器 2.合并区间 3.无重叠区间 4.用最少数量的箭引爆气球 点赞👍👍收藏🌟🌟关注💖💖 你的支持是对我最大的鼓励,我们一起努力吧!😃😃 1.坏了的计算器 题目链接: 991. 坏了的计算器 题目分析: 算法原理: 解法一:正向推导 以3转化…

直播预约 | CATIA MODSIM SmartCAE带练营第3期:让每轮设计迭代都快人一步!

▼▼免费报名链接▼▼ 达索系统企业数字化转型在线研讨会https://3ds.tbh5.com/EventDetail.aspx?eid1195&frpt 迅筑官网 ​​

OSI参考模型TCP/IP模型 二三事

计算机网络的学习离不开OSI参考模型&TCP/IP模型对各层功能与任务的了解就是学习的主要内容其二者的区别也是我们应该了解的其中 拥塞控制和流量控制 就是各层功能中 两个易混淆的概念流量控制(Flow Control):解决的是发送方和接收方之间速…

DataStream实现WordCount

目录读取文本数据读取端口数据事实上Flink本身是流批统一的处理架构,批量的数据集本质上也是流,没有必要用两套不同的API来实现。所以从Flink 1.12开始,官方推荐的做法是直接使用DataStream API,在提交任务时通过将执行模式设为BA…

imx6ull-驱动开发篇37——Linux MISC 驱动实验

目录 MISC 设备驱动 miscdevice结构体 misc_register 函数 misc_deregister 函数 实验程序编写 修改设备树 驱动程序编写 miscbeep.c miscbeepApp.c Makefile 文件 运行测试 MISC 驱动也叫做杂项驱动,也就是当某些外设无法进行分类的时候就可以使用 MISC…

C# 项目“交互式展厅管理客户端“针对的是“.NETFramework,Version=v4.8”,但此计算机上没有安装它。

C# 项目“交互式展厅管理客户端"针对的是".NETFramework,Versionv4.8”,但此计算机上没有安装它。 解决方法: C# 项目“交互式展厅管理客户端"针对的是".NETFramework,Versionv4.8”,但此计算机上没有安装它。 下载地址…

FFmpeg及 RTSP、RTMP

FFmpeg 是一个功能强大的跨平台开源音视频处理工具集 ,集录制、转码、编解码、流媒体传输等功能于一体,被广泛应用于音视频处理、直播、点播等场景。它支持几乎所有主流的音视频格式和协议,是许多媒体软件(如 VLC、YouTube、抖音等…

金山办公的服务端开发工程师-25届春招笔试编程题

1.作弊 溪染:六王毕,四海一;蜀山兀,阿房出。覆压三百余里,隔离天日。骊山北构而西折,直走咸阳。二川溶溶,流入宫墙。五步一楼,十步一阁;廊腰缦回,檐牙高啄&am…

注意力机制中为什么q与k^T相乘是注意力分数

要理解 “qkT\mathbf{q} \times \mathbf{k}^TqkT 是注意力分数”,核心是抓住注意力机制的本质目标 ——量化 “查询(q)” 与 “键(k)” 之间的关联程度,而向量点积(矩阵相乘的元素本质&#xff…

Krea Video:Krea AI推出的AI视频生成工具

本文转载自:Krea Video:Krea AI推出的AI视频生成工具 - Hello123工具导航 ** 一、平台定位与技术特性 Krea Video 是 Krea AI 推出的 AI 视频生成工具,通过结合关键帧图像与文本提示实现精准视频控制。用户可自定义视频首尾帧、为每张图片设…

C++初阶(2)C++入门基础1

C是在C的基础之上,容纳进去了面向对象编程思想,并增加了许多有用的库,以及编程范式 等。熟悉C语言之后,对C学习有一定的帮助。 本章节主要目标: 补充C语言语法的不足,以及C是如何对C语言设计不合理的地方…

ANSI终端色彩控制知识散播(II):封装的层次(Python)——不同的逻辑“一样”的预期

基础高阶各有色,本原纯真动乾坤。 笔记模板由python脚本于2025-08-22 18:05:28创建,本篇笔记适合喜欢终端色彩ansi编码和python的coder翻阅。 学习的细节是欢悦的历程 博客的核心价值:在于输出思考与经验,而不仅仅是知识的简单复述…

前端无感刷新 Token 的 Axios 封装方案

在现代前端应用中,基于 Token 的身份验证已成为主流方案。然而,Token 过期问题常常困扰开发者 —— 如何在不打断用户操作的情况下自动刷新 Token,实现 "无感刷新" 体验?本文将详细介绍基于 Axios 的解决方案。什么是无…