一、方案介绍
- 研发阶段:利用 PyTorch 的动态图特性进行快速原型验证,快速迭代模型设计。
- 灵活性与易用性:PyTorch 是一个非常灵活且易于使用的深度学习框架,特别适合研究和实验。其动态计算图特性使得模型的构建和调试变得更加直观,开发者可以在运行时修改模型结构。
- 快速原型开发:许多研究人员和开发者选择 PyTorch 进行模型训练,因为它支持快速原型开发和灵活的模型设计,能够快速验证新想法并进行迭代。
- 转换阶段:将训练好的模型通过 TorchScript 导出为 ONNX 格式,再转换为 TensorFlow 格式,最后生成 TFLite 模型。
- 专为移动和嵌入式设备优化:TensorFlow Lite 是专为移动和嵌入式设备设计的推理框架,能够在资源有限的环境中高效运行模型,确保在各种设备上实现实时推理。
- 支持模型量化和优化:TFLite 支持模型量化和优化,能够显著减小模型大小并提高推理速度,适合在手机、边缘设备等场景中使用。这使得开发者能够在不牺牲准确度的情况下,提升模型的运行效率。
- 部署阶段:将 TFLite 模型集成到 Android、iOS 或嵌入式系统中,确保模型能够在目标设备上高效运行。
- 内存和计算资源的优化:在推理阶段,使用 TFLite 可以减少内存占用和计算资源消耗,尤其是在移动设备和嵌入式系统上。这对于需要长时间运行的应用尤为重要,可以延长设备的电池寿命。
- 多种优化技术:TFLite 提供了多种优化技术,如模型量化(将浮点数转换为整数),可以进一步提高推理速度并降低功耗。这使得在实时应用中能够实现更快的响应时间,提升用户体验。
二、实例1:CNN模型的转换
注:python 版本为3.10
2.1 pytorch模型训练
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = CNNModel().to(device) # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
for epoch in range(20):for images, labels in train_loader:images, labels = images.to(device), labels.to(device) # 将数据移动到 MPS 设备optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/20], Loss: {loss.item():.6f}')# 保存模型
torch.save(model.state_dict(), 'cnn_mnist.pth')
print("Model saved as cnn_mnist.pth")
2.2 pth模型转onnx 并验证一致性
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn# 定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 加载模型并进行推理
model = CNNModel()
model.load_state_dict(torch.load('cnn_mnist.pth', weights_only=True)) # 加载保存的模型权重
model.eval() # 设置为评估模式# 创建一个示例输入
dummy_input = torch.randn(1, 1, 28, 28) # MNIST 图像的形状# 使用 PyTorch 进行推理
with torch.no_grad():pytorch_output = model(dummy_input)# 导出模型为 ONNX 格式
torch.onnx.export(model, dummy_input, 'cnn_mnist.onnx', export_params=True, opset_version=11)
print("Model exported to cnn_mnist.onnx")# 使用 ONNX 进行推理
onnx_model = onnx.load('cnn_mnist.onnx')
ort_session = ort.InferenceSession('cnn_mnist.onnx')# 准备输入数据
onnx_input = dummy_input.numpy() # 将 PyTorch 张量转换为 NumPy 数组
onnx_input = onnx_input.astype(np.float32) # 确保数据类型为 float32# 使用 ONNX 进行推理
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: onnx_input})# 比较输出
pytorch_output_np = pytorch_output.numpy() # 将 PyTorch 输出转换为 NumPy 数组
onnx_output_np = onnx_output[0] # ONNX 输出是一个列表,取第一个元素# 检查输出是否一致
if np.allclose(pytorch_output_np, onnx_output_np, atol=1e-5):print("The outputs are consistent between PyTorch and ONNX.")
else:print("The outputs are NOT consistent between PyTorch and ONNX.")# 打印输出结果
print("PyTorch output:", pytorch_output_np)
print("ONNX output:", onnx_output_np)
The outputs are consistent between PyTorch and ONNX.
PyTorch output: [[ -1.5153266 -11.934659 0.5428004 -16.058285 -3.6684208 -4.596178-14.53585 -3.3159208 -5.7872214 -5.3301578]]
ONNX output: [[ -1.5153263 -11.934658 0.5428015 -16.058285 -3.66842 -4.5961757-14.53585 -3.3159204 -5.787223 -5.3301597]]
2.3 onnx模型转tflite
参考这个项目:onnx2tflite
git clone https://github.com/MPolaris/onnx2tflite.git
cd onnx2tflite
conda install tensorflow=2.11.0
pip install .
python -m onnx2tflite --weights ../pth2onnx/cnn_mnist.onnx
2.4 onnx模型和tflite一致性验证
import numpy as np
import onnxruntime as ort
import tensorflow as tf# 1. 加载 ONNX 模型
onnx_model_path = 'cnn_mnist.onnx'
onnx_session = ort.InferenceSession(onnx_model_path)# 2. 加载 TFLite 模型
tflite_model_path = 'cnn_mnist.tflite'
tflite_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_interpreter.allocate_tensors()# 3. 准备输入数据
# 假设输入数据是 MNIST 数据集的一部分,形状为 (1, 28, 28, 1)
input_data = np.random.rand(1, 28, 28, 1).astype(np.float32) # Keras 输入
input_data_onnx = input_data.transpose(0, 3, 1, 2) # 转换为 ONNX 输入格式 (1, 1, 28, 28)# 4. 使用相同的输入数据进行推理# ONNX 模型推理
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: input_data_onnx})[0]
print("ONNX Output:", onnx_output)# TFLite 模型推理
tflite_input_details = tflite_interpreter.get_input_details()
tflite_output_details = tflite_interpreter.get_output_details()# 检查 TFLite 输入形状
print("TFLite Input Shape:", tflite_input_details[0]['shape'])# 设置 TFLite 输入
# 确保输入数据的形状与 TFLite 模型的输入要求一致
tflite_interpreter.set_tensor(tflite_input_details[0]['index'], input_data)
tflite_interpreter.invoke()
tflite_output = tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
print("TFLite Output:", tflite_output)# 5. 比较输出结果
# 计算输出的差异
onnx_difference = np.abs(onnx_output - tflite_output)# 输出结果
print("Difference (ONNX vs TFLite):", onnx_difference)# 检查是否一致
if np.all(onnx_difference < 1e-5): # 设定一个阈值print("The outputs are consistent between ONNX and TFLite models.")
else:print("The outputs are not consistent between ONNX and TFLite models.")
ONNX Output: [[ -3.7372704 -6.5073314 -1.1807165 -2.4232314 -10.638929 2.2660115-4.5868526 -2.7494073 -0.5609715 -6.331989 ]]
TFLite Input Shape: [ 1 28 28 1]
TFLite Output: [[ -3.7372704 -6.5073323 -1.180716 -2.4232314 -10.6389282.2660117 -4.5868545 -2.7494078 -0.56097114 -6.331988 ]]
Difference (ONNX vs TFLite): [[0.0000000e+00 9.5367432e-07 4.7683716e-07 0.0000000e+00 9.5367432e-072.3841858e-07 1.9073486e-06 4.7683716e-07 3.5762787e-07 9.5367432e-07]]
The outputs are consistent between ONNX and TFLite models.