CUDA高性能计算系列10:实战手写深度学习算子(Softmax)

CUDA高性能计算系列10:实战手写深度学习算子(Softmax)

摘要:纸上得来终觉浅,绝知此事要躬行。学了这么多优化技巧,是时候检验真功夫了。本篇我们将深入深度学习中最常见的算子之一——Softmax。看似简单的公式背后,隐藏着数值溢出的陷阱和并行归约的挑战。我们将手写一个能够与 PyTorch 原生性能抗衡的 Softmax Kernel。


1. Softmax 的数学原理与挑战

Softmax 函数将一个向量x xx映射为概率分布y yy
y i = e x i ∑ j e x j y_i = \frac{e^{x_i}}{\sum_{j} e^{x_j}}yi=jexjexi

1.1 数值稳定性问题 (Numerical Stability)

直接计算e x i e^{x_i}exi非常危险。
如果x i = 100 x_i = 100xi=100,则e 100 ≈ 2.6 × 10 43 e^{100} \approx 2.6 \times 10^{43}e1002.6×1043,这在 FP32 范围内没问题。
但如果x i = 1000 x_i = 1000xi=1000,则e 1000 → ∞ e^{1000} \to \inftye1000(Inf),导致 NaN 错误。

解决方案:减去最大值。
y i = e x i − max ⁡ ( x ) ∑ j e x j − max ⁡ ( x ) y_i = \frac{e^{x_i - \max(x)}}{\sum_{j} e^{x_j - \max(x)}}yi=jexjmax(x)eximax(x)
这样所有指数的指数项都在( − ∞ , 0 ] (-\infty, 0](,0]之间,结果在( 0 , 1 ] (0, 1](0,1]之间,永远不会上溢。

1.2 计算流程

这就将一个 Softmax 变成了三个阶段的计算:

  1. Reduce Max: 找到当前行的最大值m mm
  2. Reduce Sum: 计算S = ∑ e x i − m S = \sum e^{x_i - m}S=exim
  3. Element-wise Update: 计算y i = e x i − m / S y_i = e^{x_i - m} / Syi=exim/S

这就意味着我们需要遍历数据三次!如何高效地由 GPU 完成?


2. 架构设计:Grid, Block, Warp

假设输入张量形状为[Batch_Size, Dim]
通常Batch_Size很大,Dim变化范围广(从 100 到 10000+)。

2.1 策略:一行一个 Block

  • Grid Size:Batch_Size。每个 Block 处理一行数据。
  • Block Size: 256 或 1024。

如果Dim很小(< 1024),一个 Block 刚好能装下,直接用 Shared Memory 归约。
如果Dim很大,Block 需要循环处理(Grid-Stride Loop 变体)。


3. Kernel 实现:One-Pass 还是 Three-Pass?

为了教学清晰,我们先实现一个标准的Three-Pass逻辑,但在同一个 Kernel 内完成(避免多次启动 Kernel 的开销)。

#include<cuda_runtime.h>#include<math.h>// 辅助函数:Warp 内求最大值__device__floatwarpReduceMax(floatval){for(intoffset=16;offset>0;offset/=2)val=fmaxf(val,__shfl_down_sync(0xffffffff,val,offset));returnval;}// 辅助函数:Warp 内求和__device__floatwarpReduceSum(floatval){for(intoffset=16;offset>0;offset/=2)val+=__shfl_down_sync(0xffffffff,val,offset);returnval;}__global__voidsoftmax_kernel(float*input,float*output,intdim){// 1. 设置索引// blockIdx.x 对应 batch 维度(行号)introw_idx=blockIdx.x;// 指向当前行的起始地址float*row_input=input+row_idx*dim;float*row_output=output+row_idx*dim;// 2. 阶段一:求最大值 (Reduce Max)floatmax_val=-INFINITY;// 循环处理,防止 dim > blockDim.xfor(inti=threadIdx.x;i<dim;i+=blockDim.x){max_val=fmaxf(max_val,row_input[i]);}// Block 内规约最大值// 这里使用 Shared Memory 进行 Block 级规约(简化版,假设 Block=256,1个Warp处理不了)// 为了简单,我们只展示 Warp 级规约逻辑,实际需配合 Shared Memorymax_val=warpReduceMax(max_val);// 通过 Shared Memory 广播最大值给所有线程__shared__floats_max;if(threadIdx.x==0)s_max=max_val;__syncthreads();max_val=s_max;// 3. 阶段二:求指数和 (Reduce Sum)floatsum=0.0f;for(inti=threadIdx.x;i<dim;i+=blockDim.x){sum+=expf(row_input[i]-max_val);}sum=warpReduceSum(sum);__shared__floats_sum;if(threadIdx.x==0)s_sum=sum;__syncthreads();sum=s_sum;// 4. 阶段三:计算最终结果for(inti=threadIdx.x;i<dim;i+=blockDim.x){row_output[i]=expf(row_input[i]-max_val)/sum;}}

3.1 深度优化:Online Softmax

传统的 Softmax 需要遍历数据 3 次(Max -> Sum -> Update)。
有一种算法叫Online Softmax,利用数学技巧只需要遍历 2 次甚至更少。

公式推导:
维护当前的局部最大值m mm和局部和d dd
当遇到一个新的元素x xx时:

  • x > m x > mx>mm n e w = x m_{new} = xmnew=x,d n e w = d × e m − x + 1 d_{new} = d \times e^{m - x} + 1dnew=d×emx+1
  • x ≤ m x \le mxmm n e w = m m_{new} = mmnew=m,d n e w = d + e x − m d_{new} = d + e^{x - m}dnew=d+exm

这种方法可以在一次遍历中同时更新最大值和和,极大减少 Global Memory 访问。


4. 性能瓶颈分析

  1. Memory Bound: Softmax 是典型的Element-wise操作,计算量很小(也就 exp 和 div),主要时间都花在读写内存上。
  2. 优化方向
    • 确保 Global Memory 的合并访问(我们已经做到了,行内元素是连续的)。
    • 尽量把数据留在寄存器或 Shared Memory 中,避免重复读取 input。

5. 向量化读取 (Vectorized Load)

在处理 FP32 时,我们可以使用float4类型,一次读取 128 bit(4 个 float)。这能显著提高带宽利用率,减少指令数。

// 重新解释指针float4*vec_input=reinterpret_cast<float4*>(row_input);// 每次处理 4 个元素float4 data=vec_input[threadIdx.x];// ... 分别处理 data.x, data.y, data.z, data.w ...

限制:要求Dim必须是 4 的倍数,且地址必须对齐。实际工程中需要处理边界条件。


6. 总结与下篇预告

编写一个高性能的 Softmax 算子,不仅需要 CUDA 编程技巧(Shared Memory, Warp Shuffle),还需要深厚的数值分析功底(防止溢出)和算法优化思路(Online Softmax)。

至此,我们的 Kernel 代码已经能够跑在 GPU 上了。但是,怎么让 Python 里的 PyTorch 调用它呢?难道每次都要把数据存成文件,用 C++ 跑完再读回来吗?

当然不是!
下一篇CUDA系列11_PyTorch自定义C++扩展(Binding),我们将打通任督二脉,教你使用torch.utils.cpp_extension将我们写的 CUDA Kernel 编译成 Python 模块。届时,你只需要import my_cuda_ops,就能在 Python 里直接享用你亲手打造的高性能算子!


参考文献

  1. Milakov, M., & Gimelshein, N.Online Normalizer Calculation for Softmax. arXiv:1805.02867.
  2. OneFlow Team.How to Implement an Efficient Softmax Kernel.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1136958.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从0到1搭建实时日志监控系统:基于WebSocket + Elasticsearch的实战方案

1. 背景与痛点在开发分布式系统时&#xff0c;日志分散在多个服务节点中&#xff0c;传统轮询查询方式存在延迟高、资源浪费的问题。某次线上故障中&#xff0c;因未能实时发现错误日志&#xff0c;导致问题排查时间延长2小时。因此&#xff0c;决定自研一套低成本、实时性高的…

协同过滤性能优化技巧:高并发场景应用

如何让协同过滤扛住百万QPS&#xff1f;高并发推荐系统的实战优化之路 你有没有遇到过这样的场景&#xff1a;双十一刚到&#xff0c;首页推荐接口突然响应变慢&#xff0c;P99延迟飙升到500ms以上&#xff0c;用户开始抱怨“怎么老是推我不感兴趣的东西”&#xff1f;后台监控…

零基础掌握AUTOSAR诊断协议栈(UDS over CAN)

零基础吃透AUTOSAR诊断协议栈&#xff1a;从UDS到CAN&#xff0c;拆解整车刷写与故障读取的底层逻辑 你有没有遇到过这样的场景&#xff1f; 产线上的ECU突然无法刷写&#xff0c;诊断仪反复提示“安全访问拒绝”&#xff1b; 售后反馈某车型OBD灯常亮&#xff0c;但用标准工…

医疗用AutoGluon自动建模

&#x1f4dd; 博客主页&#xff1a;jaxzheng的CSDN主页 医疗AutoGluon&#xff1a;自动化建模的潜力与伦理暗礁目录医疗AutoGluon&#xff1a;自动化建模的潜力与伦理暗礁 引言&#xff1a;自动化浪潮下的医疗AI新边疆 一、技术应用场景&#xff1a;从理论到临床的实践价值 1.…

通俗解释nmodbus4在.NET Framework与Core的区别

一文讲透 nModbus4 在 .NET Framework 和 .NET Core 中的真实差异工业现场的设备通信&#xff0c;从来不是“插上线就能跑”的简单事。当你在树莓派上部署一个 Modbus 网关服务&#xff0c;却发现串口打不开&#xff1b;或者把原本运行良好的上位机程序从 Windows 迁移到 Linux…

【图像隐写】基于matlab快速四元数通用极坐标复指数变换的彩色图像零水印【含Matlab源码 14889期】

&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;欢迎来到海神之光博客之家&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49…

大规模数据检索优化:elasticsearch官网核心要点

如何让 Elasticsearch 在 PB 级数据下依然快如闪电&#xff1f;官方最佳实践全拆解你有没有遇到过这样的场景&#xff1a;凌晨三点&#xff0c;监控突然报警——Elasticsearch 集群 CPU 暴涨、查询延迟飙升到几秒甚至超时。翻看日志才发现&#xff0c;某个“看起来无害”的聚合…

【车辆控制】铰接重型车辆的稳健路径跟随控制【含Matlab源码 14890期】

&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;Matlab武动乾坤博客之家&#x1f49e;…

AI全景之第十二章第三节:光子计算、量子计算与AI

12.3 新型计算范式:光子计算、量子计算与AI 当前AI技术的飞速发展,尤其是大模型的持续迭代,对算力提出了指数级增长的需求。传统电子计算基于电子的电荷特性进行信息处理,受限于摩尔定律的放缓、能耗过高、传输延迟等固有瓶颈,已难以支撑下一代AI的发展。 在此背景下,以光…

【气动学】最优控制理论的归导定律和撞击角控制【含Matlab源码 14887期】含报告

&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;Matlab武动乾坤博客之家&#x1f49e;…

零基础掌握cp2102与Modbus协议的工业通信对接

用一根USB线直连工业设备&#xff1f;揭秘CP2102与Modbus的硬核通信实战 你有没有遇到过这样的场景&#xff1a;手头有一台老式温控仪、一台支持RS-485的电表&#xff0c;或者一个老旧PLC&#xff0c;想读点数据出来做监控或调试——但你的笔记本根本没有串口。插上USB转TTL模…

如何高效部署专业翻译模型?HY-MT1.5-7B镜像一键启动指南

如何高效部署专业翻译模型&#xff1f;HY-MT1.5-7B镜像一键启动指南 在多语言内容爆发式增长的今天&#xff0c;高质量、低延迟的翻译服务已成为全球化应用的核心基础设施。腾讯开源的混元翻译模型 HY-MT1.5-7B 凭借其在 WMT25 夺冠的技术底座和对混合语言、术语干预等复杂场景…

AVD无法运行?一文说清Intel HAXM安装全流程

AVD启动失败&#xff1f;别急&#xff0c;彻底搞懂Intel HAXM安装与避坑全指南 你有没有遇到过这样的场景&#xff1a;刚装好Android Studio&#xff0c;信心满满地创建了一个AVD准备调试应用&#xff0c;结果一点运行&#xff0c;弹出一条红色错误提示&#xff1a; “Intel …

Neo4j中的Cypher查询优化技巧

在Neo4j数据库中,Cypher查询语言是进行数据操作的核心工具。然而,面对复杂的查询条件,如何有效地组织查询语句以避免性能瓶颈是每个开发者需要面对的问题。今天,我们将通过一个具体的例子来讨论如何优化Cypher查询。 背景介绍 假设我们有以下Neo4j数据库模型: Actor(演…

工业机器人通信前的USB转232驱动安装准备指南

工业机器人通信前的USB转232驱动安装实战指南在工业自动化现场&#xff0c;你是否曾遇到这样的场景&#xff1a;调试软件已经打开&#xff0c;串口参数全部配置完毕&#xff0c;可点击“连接”按钮后却始终收不到机器人的回应&#xff1f;检查线缆、重启控制器、反复插拔USB——…

一文说清电路仿真circuits网页版中的反馈电路原理

从零搞懂反馈电路&#xff1a;用网页仿真玩转负反馈与正反馈 你有没有试过搭一个放大电路&#xff0c;结果输出不是信号被削了顶&#xff0c;就是莫名其妙地“自己振起来”&#xff1f;又或者想做个方波发生器&#xff0c;可电路死活不起振&#xff1f; 这些问题的根源&#…

【图像隐写】快速四元数通用极坐标复指数变换的彩色图像零水印【含Matlab源码 14889期】

&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;&#x1f49e;Matlab武动乾坤博客之家&#x1f49e;…

解决NumPy ImportError问题的实践与思考

背景介绍 在使用Python进行数据科学或数值计算时,NumPy是一个不可或缺的库。然而,在某些情况下,尝试导入NumPy可能会遇到各种错误,其中一种常见的问题是ImportError。本文将结合一个实际案例,探讨如何在Windows WSL2 Ubuntu环境中解决这一问题。 问题描述 假设你在一个…

CANFD协议仲裁场解析:核心要点说明

CAN FD仲裁场深度解析&#xff1a;从原理到实战的完整指南在一辆现代智能汽车中&#xff0c;成百上千个电子控制单元&#xff08;ECU&#xff09;需要通过车载网络实时交换数据。当刹车指令、雷达点云、发动机扭矩和OTA升级包同时争抢总线时&#xff0c;谁该优先通行&#xff1…

实战案例:基于车载雷达模块的CANFD与CAN对比

实战案例&#xff1a;车载毫米波雷达通信&#xff0c;为什么CANFD正在取代传统CAN&#xff1f;在一辆智能汽车的“神经系统”中&#xff0c;传感器是感知世界的“眼睛”和“耳朵”&#xff0c;而通信总线就是传递信息的“神经纤维”。当77GHz毫米波雷达每秒输出数百个目标点时&…