Cap2:Pytorch转TensorRT(上:Pytorch->ONNX)

文章目录

  • 1、pytorch导出onnx模型
  • 2、使用onnxruntime推理onnx模型
  • 3、精度对齐
  • 4、总结

深度学习框架种类繁多,想实现任意框架之间的模型转换是一件困难的事情。但现在有一个中间格式ONNX,任何框架模型都支持转为ONNX,然后也支持从ONNX转为自身框架,那么每一种框架都只需维护如何ONNX进行转换即可,大大降低了维护成本,也给使用的开发者带来遍历。

在这里插入图片描述
如图所示,中间件不止一种,但是ONNX是使用最广泛的一种。需要注意的是,中间件只是一个描述格式,比如resnet18.onnx这个从pytorch导出的onnx中间件中描述了每一个算子的属性,算子中的权重数值,算子之间的运算图。我们需要一个引擎或框架转为自身特定的格式后,再使用推理功能进行推理。(类似一张jpeg图片在电脑中只是一串数值,必须借用图片解析工具如照片浏览器等软件解析jpeg的数值后才能显示。所以需要区分数据和数据能实现的功能是两回事)

下面我们尝试从pytorch导出resnet18为onnx中间件,然后使用onnxruntime(和onnx不一样,这是一个推理引擎,只不过是专门针对onnx格式的,所以支持直接加载onnx模型推理。如果使用tensorrt推理引擎,则需要将onnx转为tensrort支持的格式后再加载推理)进行推理,并检查使用pytorch推理的结果和onnxruntime推理的结果之间的差异有多大。

1、pytorch导出onnx模型

首先需要安装:

pip install onnx  

pytorch提供了对onnx的支持,现在让我们尝试将pytorch的resnet18导出为onnx模型。

import torch
import torchvision# 创建一个符合输入shape的假数据
dummy_input = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cuda:0")
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).cuda()input_names = ["input:0"]  # 给输入节点取个名字,会伴随后续流程,方面我们定位输入节点
output_names = ["output:0"] # 同上
export_path = "resnet18.onnx"  # 保存路径# 导出
with torch.no_grad():torch.onnx.export(model=resnet18,args=dummy_input,f=export_path,opset_version=11,verbose=True,input_names=input_names,output_names=output_names)

在使用export函数时,我们手动定义了一个假数据,其shape和resnet18要求的输入一致。pytorch转onnx的原理不是通过分析语法进行转换的,而是让pytorch运行一次,然后通过追踪数据流,得到该组输入对应的计算图。 所以就是要给定一个输入,执行一遍模型,把对应的计算图记录下来,保存为onnx格式。export函数就是用的该种追踪方式导出的。(一个问题是如果模型中存在控制流,本次输入走分支1,那就只会记录只有分支1的模型,所以含有控制流的模型不能简单使用这种方法),所以这也就是为什么需要给定一组输入,同时为什么要with torch.no_grad了,因为我们只是追踪计算图而不需要梯度。

初次之外还定义了input_names和output_names,两个都是列表,因为某些模型不止一个输入输出。这里给每一个输入输出定义一个名字,后续我们可以根据名字直接定位到输入输出节点,获取数据。特别的如果想获取中间节点数据,也可以通过其名字定位节点获取数据。

opset_version是定义算子集。ONNX给每一个算子都有一个名字,比如ONNX定义了卷积名字为Conv,其他框架看到Conv开头的数据定义,明白这是一个卷积层, 会使用自身的卷积算子去映射,实现转换。但是深度学习各种算法更新很快,所以ONNX与时俱进会不断新增新的算子,你可以通过Operator Schemas查看到是否有自己想要得算子,以及它出现在那个版本。

2、使用onnxruntime推理onnx模型

当我们将pt模型转为onnx模型后,使用onnxruntime引擎进行推理。
首先需要安装推理引擎

Pip install onnxruntime  # 支持onnx中间件的推理引擎

完整的推理代码如下:

import cv2
import torchvision
import numpy as np
import onnxruntime# 获取ImageNet1K的标签
labels = torchvision.models.ResNet18_Weights.IMAGENET1K_V1.value.meta["categories"]
# 使用InferenceSession创建一个引擎以供后续推理,CPUExecutionProvider指定使用CPU
session = onnxruntime.InferenceSession("resnet18.onnx", providers=['CPUExecutionProvider'])def preprocess(cvimg):# 将图像大小调整为 (224, 224)image_resized = cv2.resize(cvimg, (224, 224))# 将图像转换为 NumPy 数组,并将像素值缩放到 [0, 1]image_normalized = image_resized / 255.0# 对图像进行归一化mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image_normalized = (image_normalized - mean) / std# 调整图像维度顺序,OpenCV 默认通道顺序为 BGR,而 PyTorch 默认为 RGBreturn np.transpose(image_normalized, (2, 0, 1)).astype(np.float32)img = cv2.imread(img_path)
inp = preprocess(img)  # 预处理,得到shape=(3,224,224)的图片
inp = np.expand_dims(inp, axis=0)  # 升维得到shape=(1,3,224,224),BS和导出ONNX定义的假数据的BS一致# run函数用于推理,output_names指明了要获取哪个节点的数据,input_feed使用字典给所有的输入指定输入数据
output = session.run(output_names=["output:0"], input_feed={"input:0": inp}) # output是列表,其中是节点"output:0"的数据,shape=(1,1000)

实际上output列表中的一个元素就是一个节点的输出,如下图如果定义两个输出节点,表示我要获取这两个节点的数据。那么output列表长度就是节点数量。
在这里插入图片描述

3、精度对齐

当我们将pt模型转为onnx模型后,使用onnxruntime引擎进行推理。考虑到1)模型数据转换时存在误差;2)不同框架(pt和onnxruntime)对同一个算子的实现也可能存在差异;3)图像预处理实现方法不同等原因,两种框架的推理结果是有一定差异的。但是只要这个差异在容忍的范围内,都是可以接受的。

下面我们同时使用pytorch框架和onnxruntime框架对pt模型和onnx模型进行推理。onnxruntime的输入输出使用numpy的ndarray数据。

为了简明,使用numpy产生随机数据,然后同时给pytorch(转tensor)和onnxruntime,保证了输入的一致性。我们期望两个框架的输出也应该尽可能的相似。

下面的代码将随机生成20个假数据,然后pt和onnxruntime分别进行运算,并将结果储存在pt_output_array和onnx_output_array,最后计算这两个结果集的差异。

import torch
import torchvision
import numpy as np
import onnxruntimeresnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
labels = torchvision.models.ResNet18_Weights.IMAGENET1K_V1.value.meta["categories"]
resnet18.eval()session = onnxruntime.InferenceSession("resnet18.onnx", providers=['CPUExecutionProvider'])# 创建ndarray用于储存输入和输出数据
pt_inp_array = np.zeros([20, 1, 3, 224, 224], dtype=np.float32)
onnx_inp_array = np.zeros([20, 1, 3, 224, 224], dtype=np.float32)pt_output_array = np.zeros([20, 1, 1000], dtype=np.float32)
onnx_output_array = np.zeros([20, 1, 1000], dtype=np.float32)for idx in range(20):dummy_input = np.random.random([1, 3, 224, 224]).astype(np.float32)  # 随机生成一个假数据pt_inp_array[idx] = dummy_input  # 将pytorch的输入记录with torch.no_grad():output = resnet18(torch.tensor(dummy_input))pt_output_array[idx] = output.cpu().detach().numpy()  # 将pytorch的输出记录onnx_inp_array[idx] = dummy_input  # 将onnx的输入记录output = session.run(output_names=["output:0","output:0"], input_feed={"input:0": dummy_input})onnx_output_array[idx] = output[0]  # 将onnx的输出记录# 检查pt和onnx输入的差异,因为是相同的,所以没有差异
np.testing.assert_allclose(pt_inp_array, onnx_inp_array, rtol=1e-10, atol=1e-10)# 检查pt和onnx输出的差异,经过测试在1e-06等级左右
np.testing.assert_allclose(pt_output_array, onnx_output_array, rtol=1e-10, atol=1e-10)

结果如图所示。np.testing.assert_allclose(pt_inp_array, onnx_inp_array, rtol=1e-10, atol=1e-10)是没有报错的,因为实际上输入是一样的,所以差异为0。但是输出报错了,总共有20000个参数,其中有19244个参数差异不满足设定的1e-10。实际上这里设定的阈值很小,一般设定在1e-5(不是绝对的,根据任务调整)。
在这里插入图片描述
设定阈值为1e-5后:np.testing.assert_allclose(pt_output_array, onnx_output_array, rtol=1e-5, atol=1e-5),不再报错,说明精度基本对齐。

4、总结

在本文中我们使用pytorch自带的工具将pt模型转为onnx模型,并使用onnxruntime推理引擎进行推理。为了保证模型转换过程中精度,进行了精度对齐的小实验,证明转换前后的误差在1e-6这个级别,是可以忍受的。

后续继续将onnx转为tensorRT进行部署,实现从pt–onnx–tensorRT这个部署路线。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/746780.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

案例分析篇00-【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】(2024年软考高级系统架构设计师冲刺知识点总结-案例分析篇-先导篇)

专栏系列文章: 2024高级系统架构设计师备考资料(高频考点&真题&经验)https://blog.csdn.net/seeker1994/category_12593400.html 案例分析篇01:软件架构设计考点架构风格及质量属性 案例分析篇11:UML设计考…

Java 常用注解

一、较不熟悉 1、@MappedJdbcTypes(JdbcType.VARCHAR) MyBatis 框架中的一个注解,用于指定某个字段或方法参数与数据库中的 JDBC 类型之间的映射关系。通常作用在实体类属性或者参数上。 如下标识username字段映射到数据库中的VARCHAR属性。 public interface UserMapper {@Se…

KY199 查找

描述: 输入数组长度 n 输入数组 a[1…n] 输入查找个数m 输入查找数字b[1…m] 输出 YES or NO 查找有则YES 否则NO 。 输入描述: 输入有多组数据。 每组输入n,然后输入n个整数,再输入m,然后再输入m个整数(1&…

疫情网课管理系统|基于springboot框架+ Mysql+Java+Tomcat的疫情网课管理系统设计与实现(可运行源码+数据库+设计文档+部署说明)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 目录 前台功能效果图 ​编辑 学生功能模块 管理员功能 教师功能模块 系统功能设计 数据库E-R图设计 lun…

人工智能入门之旅:从基础知识到实战应用(五)

一、人工智能实战项目与案例分析 1. AI入门项目电影评论情感分析实战 典型的AI入门实战项目,比如电影评论情感分析,是一个非常适合初学者的项目,因为它简单易懂,同时涵盖了自然语言处理(NLP)领域的一些基…

Ubuntu上搭建TFTP服务

Ubuntu上搭建TFTP服务 TFTP服务简介搭建TFTP服务安装TFTP服务修改配置文件 重启服务 TFTP服务简介 TFTP是一个基于UDP协议实现的用于在客户机和服务器之间进行简单文件传输的协议,适用于开销不大、不复杂的应用场合。TFTP协议专门为小文件传输而设计,只…

虚拟游戏理财 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 100分 题解: Java / Python / C 题目描述 在一款虚拟游戏中生活,你必须进行投资以增强在虚拟游戏中的资产以免被淘汰出局。 现有一家Bank,它提供有若干理财产品m,风险及…

line_profiler逐行分析代码时间

最近model训练有点瓶颈,GPU的利用率总是很低。所以看看能不能使用line_profiler来发现问题。 使用方式: 确保 line_profiler 正确安装: pip install line_profiler在需要分析的函数加上修饰器: from line_profiler import prof…

【PHP安全】PHP伪协议

PHP伪协议: file:// #访问本地文件系统http:// #访问HTTPs网址ftp:// #访问ftp URLphp:// #访问输入输出流zlib:// #压缩流data:// #数据(RFC 2397)ssh2:// #security shell2expect:// #处理交互式的流glob:// #查找匹配的文件路径phar:// #P…

33.使用ORDER BY排序

用ORDER BY子句排序行 ASC:升序排序,默认 DESC:降序排序 ORDER BY 子句在SELECT 语句的最后 在一个不明确的查询结果中排序返回的行。ORDER BY子句用于排序。如果使用了ORDER BY子句,它必须位于SQL语句的最后。 SELECT 语句的…

Siamese Network(孪生神经网络)详解

Siamese和Chinese有点像。Siam是古时候泰国的称呼,中文译作暹罗。Siamese也就是“暹罗”人或“泰国”人。Siamese在英语中是“孪生”、“连体”的意思,这是为什么呢?十九世纪泰国出生了一对连体婴儿,当时的医学技术无法使两人分离…

前端页面渲染机制

前端页面渲染机制是指在 web 开发中,浏览器如何将 HTML、CSS 和 JavaScript 转换为用户可视化的网页界面的过程。这个过程通常包括以下几个主要步骤: 加载 HTML: 首先,浏览器会获取 HTML 文件,并解析其结构。这个过程包括识别 HTM…

MySQL锁—全局锁、表级锁、行级锁详解

MySQL 锁 MySQL的锁按照锁的粒度可以分为全局锁、表级锁和行级锁。 一、全局锁 1. 概念 全局锁,是对整个数据库实例加锁,加锁后整个实例处于只读状态,后续的DML、DDL语句以及已经执行更新操作的事务提交语句都将被阻塞。 2. 应用场景 数据…

软件功能测试内容有哪些?湖南长沙软件测评公司分享

软件功能测试主要是验证软件应用程序的功能,且不管功能是否根据需求规范运行。是通过给出适当的输入值,确定输出并使用预期输出验证实际输出来测试每个功能。也可以看作“黑盒测试”,因为功能测试不用考虑程序内部结构和内部特性,…

MongoDB聚合运算符:$exp

文章目录 语法使用举例 $exp聚合运算符返回自然常数或欧拉数e的幂值&#xff08;次方&#xff09;的结果 语法 { $exp: <exponent> }<exponent>为指数&#xff0c;可以是任何数值表达式。 使用 如果参数为null或引用的字段不存在&#xff0c;$exp返回null&#…

【夏普利值——详细讲解】

夏普利值的介绍 沙普利值是合作博弈理论中的一个概念&#xff0c;由劳埃德-沙普利在1951年提出了这个概念&#xff0c;并因此在2012年获得了诺贝尔经济学奖。对于每个合作博弈&#xff0c;如联邦学习&#xff0c;可以将机构产生的模型的总提升在各个机构上形成一个有效的贡献分…

【iOS ARKit】PhysicsMotionComponent

使用 Physics BodyComponent 组件&#xff0c;通过设置物理参数、物理材质、施加作用力&#xff0c;能完全模拟物体在真实世界中的行为&#xff0c;这种方式的优点是遵循物理学规律、控制精确&#xff0c;但缺点是不直观。使用 PhysicsMotion Component组件则可以通过直接设置速…

Orange3数据预处理(清理特征组件)

清理特征 移除未使用的属性值和无用的属性&#xff0c;并对剩余的值进行排序。 输入 数据: 输入数据集 输出 数据: 过滤后的数据集 命名属性定义有时包含在数据中不出现的值。即使原始数据中没有这种情况&#xff0c;数据过滤、选择示例子集等操作也可能移除…

用python开发一个性能压测框架(超级简单)

用python开发一个性能压测框架&#xff08;超级简单&#xff09; 该框架是一个基础框架&#xff0c;超级简单&#xff0c;已经跑通&#xff0c;可以进行优化扩展 由于工作需要&#xff0c;最近开发了一款python性能压测框架&#xff0c;主要是对后端接口进行多线程压测 主要…

(二十五)Flask之MTVMVC架构模式Demo【重点:原生session使用及易错点!】

目录&#xff1a; 每篇前言&#xff1a;MTV&MVC构建一个基于MTV模式的Demo项目&#xff1a;蹦出一个问题&#xff1a; 每篇前言&#xff1a; &#x1f3c6;&#x1f3c6;作者介绍&#xff1a;【孤寒者】—CSDN全栈领域优质创作者、HDZ核心组成员、华为云享专家Python全栈领…