mul 与 reduce_sum 的优化实例

一、基础介绍

什么是 mul 与 reduce_sum?

mul 通常指元素级乘法(Element-wise Multiplication),它将两个形状相同的张量中对应位置的元素相乘,返回一个与原张量形状相同的新张量。

reduce_sum 是一种规约操作(Reduction Operation),它沿指定维度对张量的元素求和,从而 “压缩” 或 “减少” 张量的维度。如果不指定维度,则对所有元素求和,返回一个标量。

二、baseline 结构

onnx 可视化图如下:

对应代码如下:

class CustomNet(nn.Module):def __init__(self):super(CustomNet, self).__init__()def forward(self, a, b):# a: shape (1, 500, 7, 4, 13, 8)# b: shape (1, 500, 7, 4, 13, 256)# Step 1: Unsqueeze a -> (1, 500, 7, 4, 13, 8, 1)a = a.unsqueeze(-1)# Step 2: Reshape b -> (1, 500, 7, 4, 13, 8, 32)b = b.view(1, 500, 7, 4, 13, 8, 32)# Step 3: Mul (broadcast over last dim)out = a * b  # shape: (1, 500, 7, 4, 13, 8, 32)# # Step 4: ReduceSum over dim=2 (index 2 = 7 dim)out = out.sum(dim=2)  # shape: (1, 500, 4, 13, 8, 32)# # Step 5: ReduceSum over dim=1 (500 dim)out = out.sum(dim=1)  # shape: (1, 4, 13, 8, 32)# Step 6: Reshape to final outputout = out.view(-1, 13, 8, 32)  # 可根据需要调整最终输出 shapereturn outa = torch.randn(1, 500, 7, 4, 13, 8)
b = torch.randn(1, 500, 7, 4, 13, 256)
model = CustomNet()
output = model(a, b)

在征程 6M 上进行简单的模型编译与性能预估:

hb_compile -m mymodel.onnx --march nash-m --fast-perf

根据产出物得到预估 latency:2.97 ms

这个结构如何进行优化呢?

三、合并 reduce_sum

# Step 4: ReduceSum over dim=2 (index 2 = 7 dim)
out = out.sum(dim=2)  # shape: (1, 500, 4, 13, 8, 32)# Step 5: ReduceSum over dim=1 (500 dim)
out = out.sum(dim=1)  # shape: (1, 4, 13, 8, 32)

这两个 reducesum 能合并成一个,使用 dim=(1, 2)(即同时对 dim=1 和 dim=2 做 sum),前提是这两个维度的求和没有先后顺序依赖(即两个维度是独立的)

out = out.sum(dim=(1, 2))  # 一次性对 dim=1 和 dim=2 求和

PyTorch 中 。sum(dim=(1, 2)) 会按照给出的维度一次性执行 sum 操作,等价于逐个做 dim=2 然后 dim=1,因为 sum 是可交换的操作,最终结果形状完全相同。

优化后结构如下,可以看到确实少了一个 reducesum:

预估 latency: 1.75 ms

四、mul+reducesum 变成 conv

假设有两个张量:

  • a.shape = (B, C, H, W)
  • b.shape = (B, C, H, W)

常见操作是:

out = (a * b).sum(dim=[2, 3])  # 在 H 和 W 上求和,输出 shape: (B, C)# ----------细节---------------
import torch
import torch.nn as nn
a = torch.randn(1, 3, 8, 4) # 多维时,a的最后一维若与b不同,则只能是1,否则不能进行广播
b = torch.randn(1, 3, 8, 4)
c = a * b               # c的shape:torch.Size([1, 3, 8, 4])
d = c.sum(dim=[2,3])    # d的shape:torch.Size([1, 3])

注意:torch 中 a * b 是逐元素相乘(mul),而不是矩阵乘法(matmul),形状不匹配时会触发广播(复制对应列 or 行)

通过 深度卷积(depthwise convolution) 可以近似实现 Mul + ReduceSum 操作,等价的 Conv2d 实现方式,可以用 groups=B*C 的 conv2d 来实现上述操作:

import torch
import torch.nn.functional as Fdef conv_approx_mul_reducesum(a, b):B, C, H, W = a.shape# 把 b 变成卷积核,作为每个通道的 filterkernel = b.reshape(B * C, 1, H, W)# 输入 reshape 成 (1, B*C, H, W)input_ = a.reshape(1, B * C, H, W)# 深度卷积实现 mul+sum,输出 shape: (1, B*C, 1, 1)output = F.conv2d(input_, kernel, groups=B * C)# reshape 回 (B, C)return output.reshape(B, C)

conv2d 的过程是:

  • 对每个通道进行 乘法(卷积)
  • 然后在 kernel 区域内 求和

所以 F.conv2d(a, b, groups=B*C) 本质就是:对 a 和 b 逐元素相乘再求和 = Mul + ReduceSum

一致性验证:

import torch
import torch.nn as nn
import torch.nn.functional as Fa = torch.randn(1, 3, 8, 4) # 多维时,a的最后一维若与b不同,则只能是1,否则不能进行广播
b = torch.randn(1, 3, 8, 4)
c = a * b               # c的shape:torch.Size([1, 3, 8, 4])
d = c.sum(dim=[2,3])    # d的shape:torch.Size([1, 3])
print(d)def F_conv2d_approx_mul_reducesum(a, b):B, C, H, W = a.shape# 把 b 变成卷积核,作为每个通道的 filterkernel = b.reshape(B * C, 1, H, W)# 输入 reshape 成 (1, B*C, H, W)input_ = a.reshape(1, B * C, H, W)# 深度卷积实现 mul+sum,输出 shape: (1, B*C, 1, 1)output = F.conv2d(input_, kernel, groups=B * C)# reshape 回 (B, C)return output.reshape(B, C)
print(F_conv2d_approx_mul_reducesum(a,b))def nn_conv2d_approx_mul_reducesum(a, b):B, C, H, W = a.shape# 把 b 变成卷积核,作为每个通道的 filterkernel = b.reshape(B * C, 1, H, W)# 输入 reshape 成 (1, B*C, H, W)input_ = a.reshape(1, B * C, H, W)# 假设已有输入input_和卷积核kernel# kernel形状: (输出通道数, 输入通道数//groups, 核高, 核宽)# 例如:groups=B*C时,输入通道数需为groups的倍数out_channels = kernel.size(0)in_channels = kernel.size(1) * (B * C)  # 输入通道数 = 每组通道数 * groupskernel_size = (kernel.size(2), kernel.size(3))# 创建nn.Conv2d模块conv_layer = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,groups=B * C,bias=False  # 若F.conv2d未用偏置)# 将预定义的kernel赋值给conv_layer的权重conv_layer.weight.data = kernel  # 注意:需确保kernel形状与nn.Conv2d的weight格式一致# 深度卷积实现 mul+sum,输出 shape: (1, B*C, 1, 1)output = conv_layer(input_)# reshape 回 (B, C)return output.reshape(B, C)
print(nn_conv2d_approx_mul_reducesum(a,b))

输出:

tensor([[-0.3991,  0.2382, -8.5925]])
tensor([[-0.3991,  0.2382, -8.5925]])
tensor([[-0.3991,  0.2382, -8.5925]], grad_fn=<ViewBackward0>)

可以看到,结果确实一样。

真正部署时,不太建议这么做,因为小尺寸没必要(快不了多少),大尺寸硬件不支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/952119.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《代码大全》读后感:从 “功能实现者” 到 “责任承担者” 的思维跃迁

对于拥有 5 年开发经验的我来说,《代码大全 2》第一章更像是一次 “实践复盘”,让我对 “软件构建的责任” 有了更深的体悟。书中提到 “构建阶段决定了软件的质量上限”,这句话精准概括了我过往项目中的教训:曾参…

企业网站建设服务商:2025年最佳选择指南与行业洞察

摘要 企业网站建设行业在2025年持续快速发展,数字化转型推动中小企业对高效、性价比高的建站服务需求激增。本文基于行业数据和用户反馈,整理了2025年企业网站建设服务商排名前十的榜单,为中小企业提供参考。排名综…

2025年市场上微信小程序服务商:十大顶尖企业权威评测与选择指南

摘要 随着数字化转型加速,2025年微信小程序服务市场呈现爆发式增长,中小企业对高效、低成本互联网工具需求激增。本文基于行业数据和技术实力,深度评测十大微信小程序服务商,为您的企业选择提供权威参考。文末附有…

记录一下,关于前端控制并发的思路

看了前端很多文章我感觉并发不应该只是控制几个接口去发送,应该考虑到每个接口完成的时间是不同的,所以今天我试着写了一个如果并发中接口完成了请求就继续发送其他接口的js ,简单测试了一下感觉没问题,还请各位提…

Linux 交叉编译(toolchain) ARM 版 lib pcap.so 库

前言全局说明libpcap.so 编译一、说明 环境: CentOS Linux 7 (Core) Linux localhost.localdomain 3.10.0-1160.el7.x86_64 #1 SMP Mon Oct 19 16:18:59 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux二、下载源码: 官网:…

Codeforces Pinely Round 5(div.1 + div.2) A~D题解

写在开头 有不足的地方各位佬多多指教呀! 持续更新(bushi) A 题面 给定李华的初始rating \(R\) ,div.2的计分上限 \(X\),李华每次rating的变化最大值 \(D\),以及cf比赛的次数 \(n\),问李华最多可以正式参加多少…

Linux 交叉编译(toolchain) ARM版 libc.so 库

前言全局说明编译 libc.so一、说明 环境: CentOS Linux 7 (Core) Kernel 3.10.0-1127.el7.x86_64 on an x86_64二、下载源码: 官网: http://ftp.gnu.org 源码下载: http://ftp.gnu.org/gnu/libc/ 历史版本: http://…

revit api事件

revit api事件DocumentOpened事件public class Application_DocumentOpened : IExternalApplication {public IExternalApplication.Result OnStartup(ControlledApplication application){try{//注册事件application.…

《我测了5个排版工具后,终于把时间还给了内容创作》

《我测了5个排版工具后,终于把时间还给了内容创作》每天在公众号后台调整行间距、寻找合适模板,跨平台发稿时反复复制粘贴,往往要耗费 1-2 小时 —— 作为经常处理排版工作的运营,我实际测试了 5 款编辑器,希望找…

EDKII工程结构介绍

EDK2工程结构介绍一 EDk2开发环境的安装目录一、EDK2工程目录的一级结构 二、常用的目录文件 2.1 BaseTools--构建工具链 2.2 Conf--配置目录 2.3 MdePkg--基础核心包 2.4 MdeModulePkg -- 常用模块包 2.5 OvmfPkg…

《程序员修炼之道:从小工到专家》读后感3

后半部分围绕“成长与协作”展开,描绘了从“合格程序员”到“专家型开发者”的进阶路径。这部分内容跳出了单纯的技术和工具层面,聚焦于职业格局、团队协作和持续成长,让我对“专家”的定义有了全新的认知:真正的专…

Vue3组件代码编写遵循1.0

优秀的Vue组件代码编写应该遵循以下标准:组件设计原则单一职责原则 (Single Responsibility Principle)每个组件只负责一个功能 保持组件的纯粹性和可复用性 避免组件过于臃肿开放封闭原则 (Open-Closed Principle)对…

《程序员修炼之道:从小工到专家》读后感2

中间部分将焦点转向“工具与方法”,为程序员提供了一套可落地的实践指南。如果说前半部分是思想层面的觉醒,这部分就是行动层面的赋能,让我学会运用科学的工具和流程,将“把事做好”的愿望转化为实实在在的能力,深…

《程序员修炼之道:从小工到专家》读后感1

前半部分聚焦“原则与态度”,为初入行业或陷入迷茫的程序员点亮了职业道路的第一盏灯。书中没有堆砌复杂的技术细节,而是从思维模式和工作习惯入手,剖析了成为优秀开发者的核心前提,让我对“程序员”的职业认知实现…

2025西南地区优质温室大棚厂家精选推荐:深度解析重庆青程技术实力!蔬菜大棚厂家推荐

在设施农业快速发展的当下,温室大棚的适配性、耐用性直接关系到种植户的生产效率与收益。为给中小种植户、家庭农场及初创农业项目提供实用参考,结合地域气候适配性、产品实用性、用户口碑等多维度,具优势的温室大棚…

通义灵码助力美图构建AI驱动研发体系,助力提升研发能效和流程智能化

美图公司是一家以 “让艺术与科技美好交汇”为使命的科技创新公司,致力于为用户提供多元化、高质量的美图产品和服务。公司拥有包括美图秀秀、美颜相机、美图设计室、美图云修等在内的多款知名应用。截至2025年6月30日…

2025修护/二硫化硒去屑/香氛/控油蓬松/ 洗发水推荐榜:MASIL 玛丝兰(悦己容)五星领跑!长效去屑 + 温和修护,3 牌凭特色突围​

随着 2025 年消费者对头皮护理的需求从 “短期去屑” 转向 “长效控油、温和修护、屏障保护”,二硫化硒去屑洗发水作为头屑问题的核心解决方案,愈发注重 “浓度科学适配、成分协同增效、场景化护理”。综合去屑长效性…

数列分块学习笔记(锣鼓梳理额粉筷入门模板)

数列分块入门1 我们预处理出每个点所在的区块,预处理每个区块的左端点和右端点。 对于添加操作,我们先判断是否在同一区间,如果是的话就在区间里面暴力重构。 如果不是,那么就对于一整块要处理的区间的左边神域和右…

2025凝汽器/换热器/空预器/板式换热器/管式换热器/空冷岛/电磁脉冲/胶球/热网加热器/低低温省煤器/清洗设备/服务推荐榜:郑州赛为机电五星领跑!在线清洗 + 定制化,3 企凭特色突围​

随着 2025 年核电、火电、化工等高能耗领域对 “清洗设备不停机运行、节能降耗、场景定制化” 需求深化,清洗设备已从 “传统离线清洗” 转向 “在线智能清洗、全工况适配”。综合技术创新性、工况适配度、服务覆盖及…

claude-ide搭建

claude-ide搭建 说明 官方:https://www.claudeide.net/zh 按照之前的搭建好Windows的node.js环境。 步骤 安装 # 安装 Node.js 18+, 然后运行: npm install -g @anthropic-ai/claude-code配置 # cmd控制台设置 set AN…