[模型部署] 1. 模型导出

👋 你好!这里有实用干货与深度分享✨✨ 若有帮助,欢迎:​
👍 点赞 | ⭐ 收藏 | 💬 评论 | ➕ 关注 ,解锁更多精彩!​
📁 收藏专栏即可第一时间获取最新推送🔔。​
📖后续我将持续带来更多优质内容,期待与你一同探索知识,携手前行,共同进步🚀。​



人工智能

模型导出

本文介绍如何将深度学习模型导出为不同的部署格式,包括ONNX、TorchScript等,并对各种格式的优缺点和最佳实践进行总结,帮助你高效完成模型部署准备。


1. 导出格式对比

格式优点缺点适用场景
ONNX- 跨平台跨框架
- 生态丰富
- 标准统一
- 广泛支持
- 可能存在算子兼容问题
- 部分高级特性支持有限
- 跨平台部署
- 使用标准推理引擎
- 需要广泛兼容性
TorchScript- 与PyTorch无缝集成
- 支持动态图结构
- 调试方便
- 性能优化
- 仅限PyTorch生态
- 文件体积较大
- PyTorch生产环境
- 需要动态特性
- 性能要求高
TensorRT- 极致优化性能
- 支持GPU加速
- 低延迟推理
- 仅支持NVIDIA GPU
- 配置复杂
- 高性能推理场景
- 实时应用
- 边缘计算
TensorFlow SavedModel- TensorFlow生态完整支持
- 部署便捷
- 跨框架兼容性差- TensorFlow生产环境

2. ONNX格式导出

2.1 基本导出

ONNX格式适用于跨平台部署,支持多种推理引擎(如ONNXRuntime、TensorRT、OpenVINO等)。

import torch
import torch.onnxdef export_to_onnx(model, input_shape, save_path):# 设置模型为评估模式model.eval()# 创建示例输入dummy_input = torch.randn(input_shape)# 导出模型torch.onnx.export(model,               # 要导出的模型dummy_input,        # 模型输入save_path,          # 保存路径export_params=True, # 导出模型参数opset_version=11,   # ONNX算子集版本do_constant_folding=True,  # 常量折叠优化input_names=['input'],     # 输入名称output_names=['output'],   # 输出名称dynamic_axes={'input': {0: 'batch_size'},  # 动态批次大小'output': {0: 'batch_size'}})print(f"Model exported to {save_path}")# 使用示例
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
export_to_onnx(model, (1, 3, 224, 224), 'model.onnx')

2.2 验证导出模型

导出后必须进行全面验证,包括结构检查和数值对比:

  1. 结构验证
import onnx
import onnxruntime
import numpy as npdef verify_onnx_structure(onnx_path):# 加载并检查模型结构onnx_model = onnx.load(onnx_path)onnx.checker.check_model(onnx_model)# 打印模型信息print("模型输入:")for input in onnx_model.graph.input:print(f"- {input.name}: {input.type.tensor_type.shape}")print("\n模型输出:")for output in onnx_model.graph.output:print(f"- {output.name}: {output.type.tensor_type.shape}")
  1. 数值精度对比
def compare_outputs(model, onnx_path, input_data):# PyTorch结果model.eval()with torch.no_grad():torch_output = model(torch.from_numpy(input_data))# ONNX结果ort_output = verify_onnx_model(onnx_path, input_data)# 比较差异diff = np.abs(torch_output.numpy() - ort_output).max()print(f"最大误差: {diff}")return diff < 1e-5
  1. 验证 ONNX 模型
import onnx
import onnxruntime
import numpy as npdef verify_onnx_model(onnx_path, input_data):# 加载ONNX模型onnx_model = onnx.load(onnx_path)onnx.checker.check_model(onnx_model)# 创建推理会话ort_session = onnxruntime.InferenceSession(onnx_path)# 准备输入数据ort_inputs = {ort_session.get_inputs()[0].name: input_data}# 运行推理ort_outputs = ort_session.run(None, ort_inputs)return ort_outputs[0]input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = verify_onnx_model('model.onnx', input_data)

2.3 ONNX模型优化

使用ONNX Runtime提供的优化工具进一步提升性能:

import onnxruntime as ort
from onnxruntime.transformers import optimizerdef optimize_onnx_model(onnx_path, optimized_path):# 创建优化器配置opt_options = optimizer.OptimizationConfig(optimization_level=99,  # 最高优化级别enable_gelu_approximation=True,enable_layer_norm_optimization=True,enable_attention_fusion=True)# 优化模型optimized_model = optimizer.optimize_model(onnx_path, 'cpu',  # 或 'gpu'opt_options)# 保存优化后的模型optimized_model.save_model_to_file(optimized_path)print(f"优化后的模型已保存至 {optimized_path}")
  • optimizer.optimize_model() 第二个参数是优化目标设备,支持 ‘cpu’ 或 ‘gpu’。
    • 优化目标设备:指定模型优化时的目标硬件平台。例如:
      • ‘cpu’:针对 CPU 进行优化(如调整算子、量化参数等)。
      • ‘gpu’:针对 GPU 进行优化(如使用 CUDA 内核、张量核心等)。
        *运行时设备:优化后的模型可以在其他设备上运行,但性能可能受影响。例如:
      • 针对 CPU 优化的模型可以在 GPU 上运行,但可能无法充分利用 GPU 特性。
      • 针对 GPU 优化的模型在 CPU 上运行可能会报错或性能下降。
        建议保持优化目标与运行设备一致以获得最佳性能。

3. TorchScript格式导出

3.1 trace导出

适用于前向计算图结构固定的模型。

import torchdef export_torchscript_trace(model, input_shape, save_path):model.eval()example_input = torch.randn(input_shape)# 使用跟踪法导出traced_model = torch.jit.trace(model, example_input)traced_model.save(save_path)print(f"Traced model exported to {save_path}")return traced_model

3.2 script导出

适用于包含条件分支、循环等动态结构的模型。

import torch
import torch.nn as nn@torch.jit.script
class ScriptableModel(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 64, 3)self.relu = nn.ReLU()def forward(self, x):x = self.conv(x)x = self.relu(x)return xdef export_torchscript_script(model, save_path):scripted_model = torch.jit.script(model)scripted_model.save(save_path)print(f"Scripted model exported to {save_path}")return scripted_model

3.3 TorchScript模型验证

验证TorchScript模型的正确性:

def verify_torchscript_model(original_model, ts_model_path, input_data):# 原始模型输出original_model.eval()with torch.no_grad():original_output = original_model(input_data)# 加载TorchScript模型ts_model = torch.jit.load(ts_model_path)ts_model.eval()# TorchScript模型输出with torch.no_grad():ts_output = ts_model(input_data)# 比较差异diff = torch.abs(original_output - ts_output).max().item()print(f"最大误差: {diff}")return diff < 1e-5

4. 自定义算子处理

4.1 ONNX自定义算子

如需导出自定义算子,可通过ONNX扩展机制实现。

from onnx import helperdef create_custom_op():# 定义自定义算子custom_op = helper.make_node('CustomOp',           # 算子名称inputs=['input'],     # 输入outputs=['output'],   # 输出domain='custom.domain')return custom_opdef register_custom_op():# 注册自定义算子from onnxruntime.capi import _pybind_state as CC.register_custom_op('CustomOp', 'custom.domain')

4.2 TorchScript自定义算子

可通过C++扩展自定义TorchScript算子。

from torch.utils.cpp_extension import load# 编译自定义C++算子
custom_op = load(name="custom_op",sources=["custom_op.cpp"],verbose=True
)# 在模型中使用自定义算子
class ModelWithCustomOp(nn.Module):def forward(self, x):return custom_op.forward(x)

4.3 自定义算子示例

下面是一个完整的自定义算子实现示例:

// custom_op.cpp
#include <torch/extension.h>torch::Tensor custom_forward(torch::Tensor input) {return input.sigmoid().mul(2.0);
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.def("forward", &custom_forward, "Custom forward function");
}
# 在Python中使用
import torch
from torch.utils.cpp_extension import load# 编译自定义算子
custom_op = load(name="custom_op",sources=["custom_op.cpp"],verbose=True
)# 测试自定义算子
input_tensor = torch.randn(2, 3)
output = custom_op.forward(input_tensor)
print(output)

5. 模型部署示例

5.1 ONNXRuntime部署

import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transformsdef preprocess_image(image_path, input_shape):# 图像预处理transform = transforms.Compose([transforms.Resize((input_shape[2], input_shape[3])),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image = Image.open(image_path).convert('RGB')image_tensor = transform(image).unsqueeze(0).numpy()return image_tensordef onnx_inference(onnx_path, image_path, input_shape=(1, 3, 224, 224)):# 加载ONNX模型session = ort.InferenceSession(onnx_path)# 预处理图像input_data = preprocess_image(image_path, input_shape)# 获取输入输出名称input_name = session.get_inputs()[0].nameoutput_name = session.get_outputs()[0].name# 执行推理result = session.run([output_name], {input_name: input_data})return result[0]

5.2 TorchScript部署

import torch
from PIL import Image
import torchvision.transforms as transformsdef torchscript_inference(model_path, image_path):# 加载TorchScript模型model = torch.jit.load(model_path)model.eval()# 图像预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 加载并处理图像image = Image.open(image_path).convert('RGB')input_tensor = transform(image).unsqueeze(0)# 执行推理with torch.no_grad():output = model(input_tensor)return output

6. 常见问题与解决方案

6.1 ONNX导出失败

问题: 导出ONNX时出现算子不支持错误

解决方案:

# 尝试使用更高版本的opset
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=13)# 或替换不支持的操作
class ModelWrapper(nn.Module):def __init__(self, model):super().__init__()self.model = modeldef forward(self, x):# 替换不支持的操作为等效操作return self.model(x)

6.2 TorchScript跟踪失败

问题: 动态控制流导致trace失败

解决方案:

# 使用script而非trace
scripted_model = torch.jit.script(model)# 或修改模型结构避免动态控制流
class TraceFriendlyModel(nn.Module):def __init__(self, original_model):super().__init__()self.model = original_modeldef forward(self, x):# 移除动态控制流return self.model.forward_fixed(x)

6.3 推理性能问题

问题: 导出模型推理速度慢

解决方案:

# 1. 使用量化
from torch.quantization import quantize_dynamic
quantized_model = quantize_dynamic(model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8)# 2. 使用TensorRT优化ONNX
import tensorrt as trt
# TensorRT优化代码...# 3. 使用ONNX Runtime优化
import onnxruntime as ort
session = ort.InferenceSession("model.onnx", providers=['CUDAExecutionProvider'])

7. 最佳实践

  1. 选择合适的导出格式

    • ONNX:适合跨平台、跨框架部署,兼容性强
    • TorchScript:适合PyTorch生态内部署,支持灵活性高
    • 根据目标平台和性能需求选择
  2. 优化导出模型

    • 使用合适的opset版本(建议11及以上)
    • 启用常量折叠等优化选项
    • 导出后务必验证模型正确性
    • 考虑使用量化和剪枝优化模型大小
  3. 处理动态输入

    • 设置动态维度(如batch_size)
    • 测试不同输入大小,确保模型鲁棒性
    • 记录支持的输入范围和约束
  4. 文档和版本控制

    • 记录导出配置和依赖版本
    • 保存模型元数据(如输入输出规格)
    • 对模型文件进行版本化管理
    • 维护模型卡片(Model Card)记录关键信息
  5. 调试技巧

    • 使用ONNX Graph Viewer等可视化工具分析模型结构
    • 使用Netron查看计算图和参数分布
    • 比较原始与导出模型输出,检查数值精度差异
    • 遇到兼容性问题时查阅官方文档和社区经验

8. 参考资源

  • ONNX官方文档
  • PyTorch TorchScript教程
  • ONNX Runtime文档
  • TensorRT开发者指南
  • Netron模型可视化工具




📌 感谢阅读!若文章对你有用,别吝啬互动~​
👍 点个赞 | ⭐ 收藏备用 | 💬 留下你的想法 ,关注我,更多干货持续更新!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/81122.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mac的Cli为什么输入python3才有用python --version显示无效,pyenv入门笔记,如何查看mac自带的标准库模块

根据你的终端输出&#xff0c;可以得出以下结论&#xff1a; 1. 你的 Mac 当前只有一个 Python 版本 系统默认的 Python 3 位于 /usr/bin/python3&#xff08;这是 macOS 自带的 Python&#xff09;通过 which python3 确认当前使用的就是系统自带的 Pythonbrew list python …

Java注解详解:从入门到实战应用篇

1. 引言 Java注解&#xff08;Annotation&#xff09;是JDK 5.0引入的一种元数据机制&#xff0c;用于为代码提供附加信息。它广泛应用于框架开发、代码生成、编译检查等领域。本文将从基础到实战&#xff0c;全面解析Java注解的核心概念和使用场景。 2. 注解基础概念 2.1 什…

前端方法的总结及记录

个人简介 &#x1f468;‍&#x1f4bb;‍个人主页&#xff1a; 魔术师 &#x1f4d6;学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全栈发展 &#x1f6b4;个人状态&#xff1a; 研发工程师&#xff0c;现效力于政务服务网事业 &#x1f1e8;&#x1f1f3;人生格言&…

组件导航 (HMRouter)+flutter项目搭建-混合开发+分栏效果

组件导航 (Navigation)flutter项目搭建 接上一章flutter项目的环境变量配置并运行flutter 1.flutter创建项目并运行 flutter create fluter_hmrouter 进入ohos目录打开编辑器先自动签名 编译项目-生成签名包 flutter build hap --debug 运行项目 HMRouter搭建安装 1.安…

城市排水管网流量监测系统解决方案

一、方案背景 随着工业的不断发展和城市人口的急剧增加&#xff0c;工业废水和城市污水的排放量也大量增加。目前&#xff0c;我国已成为世界上污水排放量大、增加速度快的国家之一。然而&#xff0c;总体而言污水处理能力较低&#xff0c;有相当部分未经处理的污水直接或间接排…

TCP/IP 知识体系

TCP/IP 知识体系 一、TCP/IP 定义 全称&#xff1a;Transmission Control Protocol/Internet Protocol&#xff08;传输控制协议/网际协议&#xff09;核心概念&#xff1a; 跨网络实现信息传输的协议簇&#xff08;包含 TCP、IP、FTP、SMTP、UDP 等协议&#xff09;因 TCP 和…

5G行业专网部署费用详解:投资回报如何最大化?

随着数字化转型的加速&#xff0c;5G行业专网作为企业提升生产效率、保障业务安全和实现智能化管理的重要基础设施&#xff0c;正受到越来越多行业客户的关注。部署5G专网虽然前期投入较大&#xff0c;但通过合理规划和技术选择&#xff0c;能够实现投资回报的最大化。 在5G行…

网页工具-OTU/ASV表格物种分类汇总工具

AI辅助下开发了个工具&#xff0c;功能如下&#xff0c;分享给大家&#xff1a; 基于Shiny开发的用户友好型网页应用&#xff0c;专为微生物组数据分析设计。该工具能够自动处理OTU/ASV_taxa表格&#xff08;支持XLS/XLSX/TSV/CSV格式&#xff09;&#xff0c;通过调用QIIME1&a…

【超分辨率专题】一种考量视频编码比特率优化能力的超分辨率基准

这是一个Benchmark&#xff0c;超分辨率视频编码&#xff08;2024&#xff09; 专题介绍一、研究背景二、相关工作2.1 SR的发展2.2 SR benchmark的发展 三、Benchmark细节3.1 数据集制作3.2 模型选择3.3 编解码器和压缩标准选择3.4 Benchmark pipeline3.5 质量评估和主观评价研…

保姆教程-----安装MySQL全过程

1.电脑从未安装过mysql的&#xff0c;先找到mysql官网&#xff1a;MySQL :: Download MySQL Community Server 然后下载完成后&#xff0c;找到文件&#xff0c;然后双击打开 2. 选择安装的产品和功能 依次点开“MySQL Servers”、“MySQL Servers”、“MySQL Servers 5.7”、…

【React中函数组件和类组件区别】

在 React 中,函数组件和类组件是两种构建组件的方式,它们在多个方面存在区别,以下详细介绍: 1. 语法和定义 类组件:使用 ES6 的类(class)语法定义,继承自 React.Component。需要通过 this.props 来访问传递给组件的属性(props),并且通常要实现 render 方法返回 JSX…

[基础] HPOP、SGP4与SDP4轨道传播模型深度解析与对比

HPOP、SGP4与SDP4轨道传播模型深度解析与对比 文章目录 HPOP、SGP4与SDP4轨道传播模型深度解析与对比第一章 引言第二章 模型基础理论2.1 历史演进脉络2.2 动力学方程统一框架 第三章 数学推导与摄动机制3.1 SGP4核心推导3.1.1 J₂摄动解析解3.1.2 大气阻力建模改进 3.2 SDP4深…

搭建运行若依微服务版本ruoyi-cloud最新教程

搭建运行若依微服务版本ruoyi-cloud 一、环境准备 JDK > 1.8MySQL > 5.7Maven > 3.0Node > 12Redis > 3 二、后端 2.1数据库准备 在navicat上创建数据库ry-seata、ry-config、ry-cloud运行SQL文件ry_20250425.sql、ry_config_20250224.sql、ry_seata_2021012…

Google I/O 2025 观看攻略一键收藏,开启技术探索之旅!

AIGC开放社区https://lerhk.xetlk.com/sl/1SAwVJ创业邦https://weibo.com/1649252577/PrNjioJ7XCSDNhttps://live.csdn.net/room/csdnnews/OOFSCy2g/channel/collectiondetail?sid2941619DONEWShttps://www.donews.com/live/detail/958.html凤凰科技https://flive.ifeng.com/l…

ORACLE 11.2.0.4 数据库磁盘空间爆满导致GAP产生

前言 昨天晚上深夜接到客户电话&#xff0c;反应数据库无法正常使用&#xff0c;想进入服务器检查时&#xff0c;登录响应非常慢。等两分钟后进入服务器且通过sqlplus进入数据库也很慢。通过检查服务器磁盘空间发现数据库所在区已经爆满&#xff0c;导致数据库在运行期间新增审…

计算机视觉---目标追踪(Object Tracking)概览

一、核心定义与基础概念 1. 目标追踪的定义 定义&#xff1a;在视频序列或连续图像中&#xff0c;对一个或多个感兴趣目标&#xff08;如人、车辆、物体等&#xff09;的位置、运动轨迹进行持续估计的过程。核心任务&#xff1a;跨帧关联目标&#xff0c;解决“同一目标在不同…

windows系统中下载好node无法使用npm

原因是 Windows PowerShell禁用导致的npm无法正常使用 解决方法管理员打开Windows PowerShell 输入Set-ExecutionPolicy -Scope CurrentUser RemoteSigned 按Y 确认就解决了

Nginx模块配置与请求处理详解

Nginx 作为模块化设计的 Web 服务器,其核心功能通过不同模块协同完成。以下是各模块的详细配置案例及数据流转解析: 一、核心模块配置案例 1. Handler 模块(内容生成) 功能:直接生成响应内容(如静态文件、重定向等) # 示例1:静态文件处理(ngx_http_static_module)…

Elasticsearch 学习(一)如何在Linux 系统中下载、安装

目录 一、Elasticsearch 下载二、使用 yum、dnf、zypper 命令下载安装三、使用 Docker 本地快速启动安装&#xff08;ESKibana&#xff09;【测试推荐】3.1 介绍3.2 下载、安装、启动3.3 访问3.4 修改配置&#xff0c;支持ip访问 官网地址&#xff1a; https://www.elastic.co/…

Java Map双列集合深度解析:HashMap、LinkedHashMap、TreeMap底层原理与实战应用

Java Map双列集合深度解析&#xff1a;HashMap、LinkedHashMap、TreeMap底层原理与实战应用 一、Map双列集合概述 1. 核心特点 键值对结构&#xff1a;每个元素由键&#xff08;Key&#xff09;和值&#xff08;Value&#xff09;组成。键唯一性&#xff1a;键不可重复&#…