解决qnn htp 后端不支持boolean 数据类型的方法。

一、背景

  1.1 问题原因

  Qnn 模型在使用fp16的模型转换不支持类型是boolean的cast 算子,因为 htp 后端支持量化数据类型或者fp16,不支持boolean 类型。

${QNN_SDK_ROOT_27}/bin/x86_64-linux-clang/qnn-model-lib-generator -c ./bge_small_fp16.cpp -b ./bge_small_fp16.bin -o output-so-small

也就是图中的算子不支持。

尝试了很多版本,后端,都不支持。没办法只能算子替换了。

1.2 替换算子

初步思路:

       Sub↓Cast (to bool)↓Cast (to float32)    (另外一个输入,假设是 y)↓                  ↓Mul              Mul (1 - mask)↓                  ↓Add↓Output
  1. 先做一个 Greater 比较,生成 0/1 tensor

  2. 再用这个 0/1 tensor 进行 (cond * x) + ((1-cond) * y) 操作, Where(cond, x, y) = cond * x + (1 - cond) * y 可以用 Cast + Mul + Sub + Add 基础算子实现。

  3. 但是生成的还是有boolean 类型数据

不要 Greater (即不要比较生成bool类型)

不要 BOOL tensor (因为有些平台对BOOL类型支持不好,比如QNN/DSP/NPU)

直接从 float tensor 生成 0/1 的 float tensor!

改进思路:

可以直接用 Clip + Sign 这种基础算子来实现!

比如:

  • Sign(x)

    • 如果 x > 0,输出 1

    • 如果 x == 0,输出 0

    • 如果 x < 0,输出 -1

  • Clip(Sign(x), 0.0, 1.0)

    • 把负数剪到 0

    • 正数(1)保留为 1

这样就完美地直接生成了一个 全是 0 或 1FLOAT tensor! ✅ 没有 BOOL 类型,✅ 没有 Greater 节点,✅ 没有 Cast,✅ 全是 float。

real_cond_input ---> Sign ---> Clip(0.0, 1.0) ---> mask (float 0/1 tensor)

二、算子代码实现

1.1 替换算子

import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as npdef add_value_info(graph, name, dtype, shape):"""辅助函数:添加中间 tensor 的 shape 和 dtype"""vi = helper.make_tensor_value_info(name, dtype, shape)graph.value_info.append(vi)def add_constant(graph, base_name, value, dtype, shape):const_name = base_name + "_value"const_tensor = helper.make_tensor(name=const_name,data_type=dtype,dims=shape,vals=value)const_node = helper.make_node('Constant',inputs=[],outputs=[const_name],value=const_tensor)graph.node.append(const_node)add_value_info(graph, const_name, dtype, shape)return const_name
def replace_where_and_cast(model_path, output_path):"""替换 onnx 中的 Where 和 Cast 节点,保持功能等效"""# 读取模型model = onnx.load(model_path)nodes = model.graph.nodeprint("old model node number" + str(len(model.graph.node)))new_nodes = []nodes_to_remove = []input_shape = [1,1, 512, 512]for node in model.graph.node:if node.op_type == "Where":# 记录要移除的原始 Wherenodes_to_remove.append(node)# Where输入:[condition, x, y]cond_input = node.input[0]print(cond_input)x_input = node.input[1]print(x_input)y_input = node.input[2]print(y_input)output_name = node.output[0]print(output_name)# 处理可能前面有 Cast 的情况real_cond_input = cond_inputfor sub_node in model.graph.node:if sub_node.output and sub_node.output[0] == cond_input and sub_node.op_type == "Cast":real_cond_input = sub_node.input[0]nodes_to_remove.append(sub_node)break# ========== 关键步骤 ==========# 1. Signsign_output = real_cond_input + "_sign"sign_node = helper.make_node('Sign',inputs=[real_cond_input],outputs=[sign_output],name ="sign_add_my")new_nodes.append(sign_node)add_value_info(model.graph, sign_output, TensorProto.FLOAT, input_shape)# 2. Clip(0,1)clip_output = real_cond_input + "_clip"clip_min_tensor_name = real_cond_input + "_min_value"clip_min_initializer = numpy_helper.from_array(np.zeros(1, dtype=np.float32),name=clip_min_tensor_name)clip_max_tensor_name = real_cond_input + "_max_value"clip_max_initializer = numpy_helper.from_array(np.ones(1, dtype=np.float32),name=clip_max_tensor_name)model.graph.initializer.append(clip_min_initializer)model.graph.initializer.append(clip_max_initializer)# min_val_const_node = add_constant(model.graph, "min_value", 0, TensorProto.FLOAT, input_shape)# max_val_const_node = add_constant(model.graph, "max_value", 1, TensorProto.FLOAT, input_shape)clip_node = helper.make_node('Clip',inputs=[sign_output, clip_min_tensor_name, clip_max_tensor_name],outputs=[clip_output],name="clip_add_my")new_nodes.append(clip_node)add_value_info(model.graph, clip_output, TensorProto.FLOAT, input_shape)# 3. 生成 (1 - mask)one_tensor_name = real_cond_input + "_one"one_initializer = numpy_helper.from_array(np.ones(input_shape, dtype=np.float32),name=one_tensor_name)model.graph.initializer.append(one_initializer)one_minus_mask_output = real_cond_input + "_one_minus_mask"sub_node = helper.make_node('Sub',inputs=[one_tensor_name, clip_output],outputs=[one_minus_mask_output],name="sub_my")new_nodes.append(sub_node)add_value_info(model.graph, one_minus_mask_output, TensorProto.FLOAT, input_shape)# 4. mask * xmask_mul_x_output = real_cond_input + "_mask_mul_x"mul1_node = helper.make_node('Mul',inputs=[clip_output, x_input],outputs=[mask_mul_x_output],name="mul_my")new_nodes.append(mul1_node)add_value_info(model.graph, mask_mul_x_output, TensorProto.FLOAT, input_shape)# 5. (1-mask) * yone_minus_mask_mul_y_output = real_cond_input + "_one_minus_mask_mul_y"mul2_node = helper.make_node('Mul',inputs=[one_minus_mask_output, y_input],outputs=[one_minus_mask_mul_y_output],name="mul_my2")new_nodes.append(mul2_node)add_value_info(model.graph, one_minus_mask_mul_y_output, TensorProto.FLOAT, input_shape)# 6. 加起来得到最终输出add_node = helper.make_node('Add',inputs=[mask_mul_x_output, one_minus_mask_mul_y_output],outputs=[output_name],name="add_my")new_nodes.append(add_node)# output shape 已经有定义,不需要额外addelif node.op_type == 'Cast':# 如果是 Where 的 Cast,不保留if any(wn.input[0] == node.output[0] for wn in nodes if wn.op_type == 'Where'):print(f"Skipping Cast node: {node.name}")continueelse:new_nodes.append(node)else:new_nodes.append(node)# 移除旧节点for node in nodes_to_remove:model.graph.node.remove(node)# 更新新的节点列表model.graph.ClearField('node')model.graph.node.extend(new_nodes)print("new model node number" + str(len(model.graph.node)))# 保存新的模型onnx.save(model, output_path)if __name__ == "__main__":model_path = "./bge_small_model_simple.onnx"output_path = "./bge_replace_cast_where2.onnx"replace_where_and_cast(model_path, output_path)

 

2.2 运行原始模型和算子替换之后的模型

def run_bge_small_model_onnx():model = AutoModel.from_pretrained("BAAI/bge-small-zh-v1.5")tokenizers = AutoTokenizer.from_pretrained("BAAI/bge-small-zh-v1.5")input_data = "ZhongGuo, nihao, 日本再见, good cat!"device = "cuda" if torch.cuda.is_available() else "cpu"model.to(device)model.eval()input_tensor_data = tokenizers(input_data, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(device)with torch.no_grad():output = model(**input_tensor_data)print("oringal model putput")output_data = output.last_hidden_state.flatten().tolist()[:100]print(len(output.last_hidden_state.flatten().tolist()))print(output_data)print("run modify model")# 步骤 2:加载 ONNX 模型model_path = './bge_replace_cast_where2.onnx'  # 替换为你的 ONNX 模型文件路径session = ort.InferenceSession(model_path)# 步骤 3:准备输入数据# 假设模型的输入是一个形状为 (1, 3, 224, 224) 的浮点张量input_name1 = session.get_inputs()[0].nameprint(input_name1)input_data1 = input_tensor_data["input_ids"].numpy()input_name2 = session.get_inputs()[1].nameinput_data2 = input_tensor_data["attention_mask"].numpy()print(input_name2)input_name3 = session.get_inputs()[2].nameinput_data3 = input_tensor_data["token_type_ids"].numpy()print(input_name3)# 步骤 4:运行模型并获取输出replace_model_output = session.run(None, {input_name1: input_data1, input_name2: input_data2, input_name3: input_data3})# 打印输出结果print("replace_model_output shape:", replace_model_output[0].shape)print("replace_model_output data:", replace_model_output[0])replace_model_output_data = replace_model_output[:100]print(len(replace_model_output))print(replace_model_output_data)np.array(replace_model_output).tofile("last_output-onnx_bge_small_replace.raw")

2.3 原始模型和替换算子模型精度对齐


def compare_nchw_data(nchw_file, nchw_file2):data_nchw = read_bin_fp32(nchw_file, shape=[1, 512, 512])print("NCHW 原始数据形状:", data_nchw.shape)print("NCHW 数据统计 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(data_nchw.min(), data_nchw.max(), data_nchw.mean()))data_nchw2 = read_bin_fp32(nchw_file2, shape=[1, 512, 512])print("NHWC2 原始数据形状:", data_nchw2.shape)print("NHWC2 数据统计 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(data_nchw2.min(), data_nchw2.max(), data_nchw2.mean()))diff = data_nchw - data_nchw2print("\n==== 差异对比 ====")print("差值 min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(diff.min(), diff.max(), diff.mean()))print(diff)# ==== 打印前100个数据 ====onnx_output_flat = data_nchw.flatten()onnx_output_flat2 = data_nchw2.flatten()print("\n--- 前100个元素 ---")for i in range(100):print(f"[{i}] onnx-v={onnx_output_flat[i]:.6f} | qnn-v={onnx_output_flat2[i]:.6f} | diff={abs(onnx_output_flat[i] - onnx_output_flat2[i]):.6f}")# ==== 打印后100个数据 ====print("\n--- 后100个元素 ---")for i in range(-100, 0):idx = len(onnx_output_flat) + iprint(f"[{idx}] onnx-v={onnx_output_flat[i]:.6f} | qnn-v={onnx_output_flat2[i]:.6f} | diff={abs(onnx_output_flat[i] - onnx_output_flat2[i]):.6f}")# ==== 可选:统计误差 ====max_diff = np.max(onnx_output_flat2 - onnx_output_flat)mean_diff = np.mean(onnx_output_flat2 - onnx_output_flat )min_diff = np.min(onnx_output_flat2 -onnx_output_flat)print(f"\n 总元素数: {onnx_output_flat.size}")print(f" 最大误差: {max_diff}")print(f" 最小误差: {min_diff}")print(f" 平均误差: {mean_diff}")

2.4 对齐结果展示

 

 

结果对齐了,表示模型替换成功了。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/80696.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Three.js搭建自己的3Dweb模型(从0到1无废话版本)

教学视频参考&#xff1a;B站——Three.js教学 教学链接&#xff1a;Three.js中文网 老陈打码 | 麒跃科技 一.什么是Three.js&#xff1f; Three.js​ 是一个基于 JavaScript 的 ​3D 图形库&#xff0c;用于在网页浏览器中创建和渲染交互式 3D 内容。它基于 WebGL&#xff0…

PostgreSQL WAL 幂等性详解

1. WAL简介 WAL&#xff08;Write-Ahead Logging&#xff09;是PostgreSQL的核心机制之一。其基本理念是&#xff1a;在修改数据库数据页之前&#xff0c;必须先将这次修改操作写入到WAL日志中。 这确保了即使发生崩溃&#xff0c;数据库也可以根据WAL日志进行恢复。 恢复的核…

git提交规范记录,常见的提交类型及模板、示例

Git提交规范是一种约定俗成的提交信息编写标准&#xff0c;旨在使代码仓库的提交历史更加清晰、可读和有组织。以下是常见的Git提交类型及其对应的提交模板&#xff1a; 提交信息的基本结构 一个标准的Git提交信息通常包含以下三个主要部分&#xff1a; Header‌&#xff1a;描…

FastAPI系列06:FastAPI响应(Response)

FastAPI响应&#xff08;Response&#xff09; 1、Response入门2、Response基本操作设置响应体&#xff08;返回数据&#xff09;设置状态码设置响应头设置 Cookies 3、响应模型 response_model4、响应类型 response_classResponse派生类自定义response_class 在“FastAPI系列0…

每日一题(小白)模拟娱乐篇33

首先&#xff0c;理解题意是十分重要的&#xff0c;我们是要求最短路径&#xff0c;这道题可以用dfs&#xff0c;但是题目给出的数据是有规律的&#xff0c;我们可以尝试模拟的过程使用简单的方法做出来。每隔w数字就会向下转向&#xff0c;就比如题目上示例的w6&#xff0c;无…

哈希封装unordered_map和unordered_set的模拟实现

文章目录 &#xff08;一&#xff09;认识unordered_map和unordered_set&#xff08;二&#xff09;模拟实现unordered_map和unordered_set2.1 实现出复用哈希表的框架2.2 迭代器iterator的实现思路分析2.3 unordered_map支持[] &#xff08;三&#xff09;结束语 &#xff08;…

Java学习-Java基础

1.重写与重载的区别 重写发生在父子类之间,重载发生在同类之间构造方法不能重写,只能重载重写的方法返回值,参数列表,方法名必须相同重载的方法名相同,参数列表必须不同重写的方法的访问权限不能比父类方法的访问权限更低 2.接口和抽象类的区别 接口是interface,抽象类是abs…

BG开发者日志0427:故事的起点

1、4月26日晚上&#xff0c;BG项目的gameplay部分开发完毕&#xff0c;后续是细节以及试玩版优化。 开发重心转移到story部分&#xff0c;目前刚开始&#xff0c; 确切地说以前是长期搁置状态&#xff0c;因为过去的四个月中gameplay部分优先开发。 --- 2、BG这个项目的起点…

头歌实训之游标触发器

&#x1f31f; 各位看官好&#xff0c;我是maomi_9526&#xff01; &#x1f30d; 种一棵树最好是十年前&#xff0c;其次是现在&#xff01; &#x1f680; 今天来学习C语言的相关知识。 &#x1f44d; 如果觉得这篇文章有帮助&#xff0c;欢迎您一键三连&#xff0c;分享给更…

【深度学习】多头注意力机制的实现|pytorch

博主简介&#xff1a;努力学习的22级计算机科学与技术本科生一枚&#x1f338;博主主页&#xff1a; Yaoyao2024往期回顾&#xff1a;【深度学习】注意力机制| 基于“上下文”进行编码,用更聪明的矩阵乘法替代笨重的全连接每日一言&#x1f33c;: 路漫漫其修远兮&#xff0c;吾…

java16

1.API续集 可以导入别人写好的clone的jar包 注意&#xff1a;方法要有调用者&#xff0c;如果调用者是null就会报错 2.如何导入别人写好的jar包 复制jar包然后粘贴在lib里面&#xff0c;然后右键点击jar包再点击下面的add 3.关于打印java中的引用数据类型

PostgreSQL的扩展 credcheck

PostgreSQL的扩展 credcheck credcheck 是 PostgreSQL 的一个安全扩展&#xff0c;专门用于强制实施密码策略和凭证检查&#xff0c;特别适合需要符合安全合规要求的数据库环境。 一、扩展概述 1. 主要功能 强制密码复杂度要求防止使用常见弱密码密码过期策略实施密码重复使…

MyBatis中的@Param注解-如何传入多个不同类型的参数

mybatis中参数识别规则 默认情况下,MyBatis 会按照参数位置自动分配名称:param1, param2, param3, ...或者 arg0, arg1。 // Mapper 接口方法 User getUserByIdAndName(Integer id, String name); 以上接口在XML中只能通过param1或者arg0这样的方式来引用,可读性差。 &l…

DIFY教程第一集:安装Dify配置环境

一、Dify的介绍 https://dify.ai/ Dify 是一款创新的智能生活助手应用&#xff0c;旨在为您提供便捷、高效的服务。通过人工智能技术&#xff0c; Dify 可以实现语音 助手、智能家居控制、日程管理等功能&#xff0c;助您轻松应对生活琐事&#xff0c;享受智慧生活。简约的…

5、Rag基础:RAG 专题

RAG 简介 什么是检索增强生成? 检索增强生成(RAG)是指对大型语言模型输出进行优化,使其能够在生成响应之前引用训练数据来源之外的权威知识库。大型语言模型(LLM)用海量数据进行训练,使用数十亿个参数为回答问题、翻译语言和完成句子等任务生成原始输出。在 LLM 本就强…

GAMES202-高质量实时渲染(homework1)

目录 Homework1shadow MapPCF(Percentage Closer Filter)PCSS(Percentage Closer Soft Shadow) GitHub主页&#xff1a;https://github.com/sdpyy1 作业实现:https://github.com/sdpyy1/CppLearn/tree/main/games202 Homework1 shadow Map 首先需要完成MVP矩阵的构造&#xf…

JDK(Ubuntu 18.04.6 LTS)安装笔记

一、前言 本文与【MySQL 8&#xff08;Ubuntu 18.04.6 LTS&#xff09;安装笔记】同批次&#xff1a;先搭建数据库&#xff0c;再安装JDK&#xff0c;后面肯定就是部署Web应用&#xff1a;典型的单机部署。“麻雀虽小五脏俱全”&#xff0c;善始善终&#xff0c;还是记下来吧。…

软件测试之接口测试常见面试题

一、什么是(软件)接口测试? 接口测试&#xff1a;是测试系统组件间接口的一种测试方法 接口测试的重点&#xff1a;检查数据的交换&#xff0c;数据传递的正确性&#xff0c;以及接口间的逻辑依赖关系 接口测试的意义&#xff1a;在较早期开展&#xff0c;在软件开发的同时…

Lua 第11部分 小插曲:出现频率最高的单词

在本章中&#xff0c;我们要开发一个读取并输出一段文本中出现频率最高的单词的程序。像之前的小插曲一样&#xff0c;本章的程序也十分简单但是也使用了诸如迭代器和匿名函数这样的高级特性。 该程序的主要数据结构是一个记录文本中出现的每一个单词及其出现次数之间关系的表。…

软件项目进度管理活动详解

目录 1. 活动定义&#xff08;Activity Definition&#xff09; 2. 活动排序&#xff08;Activity Sequencing&#xff09; 3. 活动资源估算&#xff08;Activity Resource Estimating&#xff09; 4. 活动历时估算&#xff08;Activity Duration Estimating&#xff09; …