高性能计算-CUDA-mma-PTX

news/2025/11/8 14:59:13/文章来源:https://www.cnblogs.com/anluo8/p/19201905

1. 简介

  • 用 mma PTX 指令实现 M16N16K16 矩阵乘法

2. 代码

  • 调用1:wmma + sharedM
  • 调用2:wmma + sharedM + padding 避免 bankcoflict
  • 调用3:mma + sharedM + swizzle 避免 bankcoflict
//A 16*16; B 16*16
//wmma 处理 half 使用 16*16*16 size 的 matrix,并使用padding 优化
//mma swizzle#include <iostream>
#include <cuda_runtime.h>#include "common/tester.h"
#include "common/common.h"using namespace nvcuda;__device__ __forceinline__ void ld_st_128bit(void *dst, void *src)
{*reinterpret_cast<float4 *>(dst) = *reinterpret_cast<float4 *>(src);
}//wmma + sharedM
template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_wmma_kernel(half *A, half *B, half *C)
{__shared__ half smem_a[M * K];__shared__ half smem_b[K * N];__shared__ half smem_c[M * N];int tx = threadIdx.x;uint32_t nPerThreadLoad = M*K/32; //8ld_st_128bit(smem_a + nPerThreadLoad * tx, A + nPerThreadLoad * tx);ld_st_128bit(smem_b + nPerThreadLoad * tx, B + nPerThreadLoad * tx);__syncthreads();wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major> a_frag;wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::row_major> b_frag;wmma::fragment<wmma::accumulator, M, N, K, half> c_frag;wmma::fill_fragment(c_frag, 0.0f);//load_matrix_sync 底层使用 ldmatrix ptx指令的加载分块矩阵的方式,所以bankConflict 分析可以参考 mma ldmatrix 的加载数据方式wmma::load_matrix_sync(a_frag, smem_a, K);  //ldmwmma::load_matrix_sync(b_frag, smem_b, N);wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);wmma::store_matrix_sync(smem_c, c_frag, N, wmma::mem_row_major);__syncthreads();ld_st_128bit(C + nPerThreadLoad * tx, smem_c + nPerThreadLoad * tx);
}//优化1:wmma + sharedM + padding 避免 bankcoflict
template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_padding_wmma_kernel(half *A,half *B,half *C)
{// +8 OFFSET 也可以是别的参数const uint32_t OFFSET = 8;__shared__ half smem_a[M][K+OFFSET];__shared__ half smem_b[K][N+OFFSET];__shared__ half smem_c[M*N];uint32_t tx = threadIdx.x;uint32_t nPerThreadLoad = M*K/32;   //8 128bitld_st_128bit(&smem_a[tx/2][tx%2*nPerThreadLoad],A+tx*nPerThreadLoad);ld_st_128bit(&smem_b[tx/2][tx%2*nPerThreadLoad],B+tx*nPerThreadLoad);__syncthreads();wmma::fragment<wmma::matrix_a, M, N, K, half,wmma::row_major> a_frag;wmma::fragment<wmma::matrix_b, M, N, K, half,wmma::row_major> b_frag;wmma::fragment<wmma::accumulator,M,N,K,half> c_frag;wmma::fill_fragment(c_frag,0.0f);wmma::load_matrix_sync(a_frag,(half*)smem_a,K+OFFSET);   wmma::load_matrix_sync(b_frag,(half*)smem_b,N+OFFSET);wmma::mma_sync(c_frag,a_frag,b_frag,c_frag);wmma::store_matrix_sync(smem_c,c_frag,N,wmma::mem_row_major);__syncthreads();ld_st_128bit(C + tx*nPerThreadLoad,smem_c + nPerThreadLoad*tx);
}// mma #define REG(val) (*reinterpret_cast<uint32_t*>(&(val)))
#define HALF2(val) (*reinterpret_cast<half2*>(&(val)))
//ptx 
__device__ __forceinline__ void ldmatrix_sync(half *dst,half *src)
{//sm_90之后支持// "=r"约束符用于‌输出操作数,= 符号表示这是一个只写操作数(输出操作数),r 表示操作数应该放在通用寄存器中// "l"约束符用于‌输入操作数‌,表示该操作数用于提供地址信息asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];": "=r"(REG(dst[0])),"=r"(REG(dst[2])),"=r"(REG(dst[4])),"=r"(REG(dst[6])): "l"(__cvta_generic_to_shared(src)));
}__device__ __forceinline__ void ldmatrix_trans_sync(half *dst, void *src)
{//LD.trans trans frament 块内 N格式,寄存器按列存储读取。不改变原矩阵的数据排布,只改变寄存器的读写方向。asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.trans.b16 {%0, %1, %2, %3}, [%4];": "=r"(REG(dst[0])),"=r"(REG(dst[2])),"=r"(REG(dst[4])),"=r"(REG(dst[6])): "l"(__cvta_generic_to_shared(src)));
}__device__ __forceinline__ void mma_sync_m16n8k16(half *c, half *a, half *b)
{asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 ""{%0, %1}, ""{%2, %3, %4, %5}, ""{%6, %7}, ""{%8, %9};": "=r"(REG(c[0])), "=r"(REG(c[2])): "r"(REG(a[0])),"r"(REG(a[2])),"r"(REG(a[4])),"r"(REG(a[6])),"r"(REG(b[0])),"r"(REG(b[2])),"r"(0),"r"(0));
}__device__ __forceinline__ void stmatrix_sync(half *dst, half *src)
{//sm_100 laterasm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};": // 无输出操作数,它从C/C++代码的角度看,是‌消耗了输入操作数‌(即寄存器%1和%2中的数据),整个操作过程并没有产生一个可供C/C++代码使用的“返回值”或“输出变量”,因此被声明为“无输出操作数”: "l"(__cvta_generic_to_shared(dst)),"r"(REG(src[0])),"r"(REG(src[2])),"r"(REG(src[4])),        "r"(REG(src[6])): "memory"  //“memory”字段属于clobber列表的一部分,它告知编译器该汇编指令可能会修改内存内容,使用"memory"可以确保编译器不会缓存内存中的旧值,从而保证后续操作能读取到最新的数据); 
}__device__ __forceinline__ void stmatrix_sync_(half *dst, half *src)
{//for sm_100 之前的版本// ! Ampere doesn't have stmatrix.sync, we should simulate ituint64_t private_addr = (uint64_t)dst;uint64_t shared_addr[4];
#pragma unrollfor (int i = 0; i < 4; i++){//广播 i * 8 + threadIdx.x / 4 通道的 private_addr值shared_addr[i] = __shfl_sync(0xFFFFFFFF, private_addr, i * 8 + threadIdx.x / 4);}
#pragma unrollfor (int i = 0; i < 4; i++){*(reinterpret_cast<half2 *>(shared_addr[i]) + threadIdx.x % 4) = HALF2(src[2 * i]);}
}// mma + sharedM + fragment的寄存器存储
template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_mma_kernel(half *A, half *B, half *C)
{__shared__ half smem_a[M * K];__shared__ half smem_b[K * N];__shared__ half smem_c[M * N];int tx = threadIdx.x;uint32_t nPerThreadLoad = M*K/32; //8//共享内存数据与原数据排布一致,当做矩阵的排布ld_st_128bit(smem_a + nPerThreadLoad * tx, A + nPerThreadLoad * tx);ld_st_128bit(smem_b + nPerThreadLoad * tx, B + nPerThreadLoad * tx);__syncthreads();wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major> a_frag;wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::row_major> b_frag;wmma::fragment<wmma::accumulator, M, N, K, half> c_frag;wmma::fill_fragment(c_frag, 0.0f);//共享内存数据转存到 fragment 寄存器中,排布不变// mma 要求 A按行存储;B(因为是16*8需要trans)输入应按列存储//此处 row col 的计算方法是将原共享内存矩阵数组转化为 fragment 存储格式的算法,与行列互换无关uint32_t row = tx%K;uint32_t col = tx/K;//共享内存地址固定写法// fragment.x 指向寄存器数组,第一个寄存器为 fragment.x[0];// 这里各个线程统一传入 fragment.x,底层可能另有处理ldmatrix_sync(a_frag.x, smem_a + row*K + col*8);    ldmatrix_trans_sync(b_frag.x, smem_b + row*K + col*8 );#if 0 //fragment_.num_storage_elements:32if(tx ==0)printf("size: %ld, addr: %p\n",sizeof(a_frag.x),&REG(a_frag.x[0]));if(tx ==1)printf("size: %ld, addr: %p\n",sizeof(a_frag.x),&REG(a_frag.x[1]));#endifmma_sync_m16n8k16(c_frag.x, a_frag.x, b_frag.x);// 偏移量4 = 4个寄存器 × 2元素/寄存器 = 8个FP16元素 mma_sync_m16n8k16(c_frag.x+4, a_frag.x, b_frag.x+4);stmatrix_sync(smem_c + row*K + col*nPerThreadLoad,c_frag.x);__syncthreads();ld_st_128bit(C + nPerThreadLoad * tx, smem_c + nPerThreadLoad * tx);
}// mma + sharedM + swizzle 避免 bankcoflict
/*
原数据地址与共享内存地址转换
对于gaddr 在共享内存的原地址addr: B 表示列需要的二进制位数,M 表示一个块内的元素索引需要的二进制位数,S 表示 addr 地址按块划分的行列坐标需要位移的二进制位数=M
*/
template<uint32_t B,uint32_t M,uint32_t S>
__device__ __forceinline__ uint32_t swizzle(uint32_t srcAddr)
{//行列坐标值取后三位进行异或运算//掩码用来获取行坐标uint32_t mask = (1 << B - 1) << M;uint32_t addr = ((srcAddr >> S) & mask) ^ srcAddr;return addr;
}template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_mma_swizzle_kernel(half *A, half *B, half *C)
{__shared__ half smem_a[M * K];__shared__ half smem_b[K * N];__shared__ half smem_c[M * N];int tx = threadIdx.x;uint32_t nPerThreadLoad = M*K/32; //8uint32_t offset = tx * nPerThreadLoad;//根据线程在全局内存中取数据逻辑地址计算共享内存物理地址存放数据//16行 3位表示;2列 1位表示;一块8个元素 3位表示uint32_t g2sAddr = swizzle<3,1,3>(offset);ld_st_128bit(smem_a + g2sAddr, A+offset);ld_st_128bit(smem_b + g2sAddr, B+offset);__syncthreads();wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major> a_frag;wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::row_major> b_frag;wmma::fragment<wmma::accumulator, M, N, K, half> c_frag;wmma::fill_fragment(c_frag, 0.0f);//从共享内存取出数据,计算线程原本要加载数据逻辑坐标,通过 swizzle 计算物理坐标//保持原矩阵布局不变的 row col 计算uint32_t row = tx%16;uint32_t col = tx/16;//根据当前线程提供数据地址逻辑计算在共享内存物理地址uint32_t s2rAddr = swizzle<3,1,3>(row*M+col*nPerThreadLoad);ldmatrix_sync(a_frag.x,smem_a+s2rAddr);ldmatrix_trans_sync(b_frag.x,smem_b+s2rAddr);#if 1   //使用 wmma api计算wmma::mma_sync(c_frag,a_frag,b_frag,c_frag);#else   //使用mma ptx 指令计算mma_sync_m16n8k16(c_frag.x,a_frag.x,b_frag.x);mma_sync_m16n8k16(c_frag.x+4,a_frag.x,b_frag.x+4);#endifstmatrix_sync(smem_c+ s2rAddr,c_frag.x);//从共享内存取数据,逻辑地址为 tx*nPerThreadLoad,物理地址为 g2sAddrld_st_128bit(C+tx*nPerThreadLoad,smem_c+g2sAddr);
}//以下为调用
void sharedM_wmma(half *A, half *B, half *C, int M, int N, int K)
{const int WMMA_M = 16;const int WMMA_N = 16;const int WMMA_K = 16;dim3 block(32);dim3 grid(1);sharedM_wmma_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}void sharedM_padding_wmma(half *A, half *B, half *C, int M, int N, int K)
{const int WMMA_M = 16;const int WMMA_N = 16;const int WMMA_K = 16;dim3 block(32);dim3 grid(1);sharedM_padding_wmma_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}void sharedM_mma(half *A, half *B, half *C, int M, int N, int K)
{const int WMMA_M = 16;const int WMMA_N = 16;const int WMMA_K = 16;dim3 block(32);dim3 grid(1);sharedM_mma_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}void sharedM_mma_swizzle(half *A, half *B, half *C, int M, int N, int K)
{const int WMMA_M = 16;const int WMMA_N = 16;const int WMMA_K = 16;dim3 block(32);dim3 grid(1);sharedM_mma_swizzle_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}
int main(int argc, char *argv[])
{// {//     Tester tester(16, 16, 16, 1, 1, 100, true);//     tester.evaluate(shareM_wmma, "sharedM_wmma");// }// {//     Tester tester(16, 16, 16, 1, 1, 100, true);//     tester.evaluate(sharedM_padding_wmma, "sharedM_padding_wmma");// }// {//     Tester tester(16, 16, 16, 1, 1, 100, true);//     tester.evaluate(sharedM_mma, "sharedM_mma");// }{Tester tester(16, 16, 16, 1, 1, 100, true);tester.evaluate(sharedM_mma_swizzle, "sharedM_mma_swizzle");}return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/959771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年口碑好的GEO(AI搜索优)服务商解析与推荐

文章摘要 本文深入解析2025年GEO(AI搜索优)服务商的市场现状,重点推荐口碑优秀的服务商如摘星AI。内容涵盖服务商选择标准、行业趋势分析,并提供数据支持的比较,帮助用户做出明智决策。文章基于权威行业报告,旨在为…

2025年手机壳厂家革新包装技术:离心式包装机深度解析

文章摘要 本文探讨2025年手机壳行业包装技术趋势,重点解析离心式包装机在提升效率、降低成本方面的优势。基于合肥摘星人工智能应用软件有限公司的经验,分享如何通过智能包装解决方案优化手机壳生产流程,覆盖湖南省…

2025年广州工商注册公司权威推荐榜单:税务股权架构方案/工商变更/工商注销源头公司精选

在广州这座创业活力之都,每天都有大量市场主体诞生。据《2025年广州市中小微企业服务市场发展白皮书》显示,广州作为大湾区核心引擎,2024年新增市场主体超过30万户。然而,“创业第一步”——公司注册正变得日益复杂…

51单片机使用TM1638驱动的数码管键盘模块

51单片机使用TM1638驱动的数码管键盘模块带k的都是可以按键扫描的,SEG和GR是数码管段和位,STB,CLK,DIO是与数据相关的引脚数据手册有说,不管芯片连接的是共阳极数码管还是共阴极数码管,SEG都必须接阳极,GR接阴极,…

2025年专业办公空间装修公司排行

摘要 随着企业对于办公环境需求的不断提升,办公空间装修行业在2025年呈现出智能化、环保化、个性化的发展趋势。本文基于市场调研和用户口碑,整理了目前行业内前十的办公空间装修公司推荐榜单,旨在为企业主提供参考…

记一次 float64 排序失效的灵异事件

某一天的下午,我手头没什么事情,双眼迷离,正左手托着下巴空洞地盯着屏幕发呆。恍惚间,BUG反馈群冷不丁冒了消息,我定下神来看,测试同学反馈了一个排行榜的排序问题,排行榜中相同分数的玩家,后达到分数的反而排…

完整教程:TypeScript 面试题及详细答案 100题 (21-30)-- 接口(Interface)

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

详细介绍:SkyDiffusion:用 BEV 视角打开街景→航拍图像合成新范式

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

Blender中如何让导出的FBX模型文件同时携带多个动画片段

Blender版本:V4.5; Unity版本:团结引擎 V1.7.3; 问题描述: 起因是博主本人最近在学习Unity,使用到了Blender对3D模型进行动画片段制作,但是博主在学习过程中发现,我使用Blender导出的FBX文件一次只能携带一个动…

精美的vue流程设计器

一、vue-dawn-flow介绍 vue-dawn-flow是一款功能强大的开源流程设计器,专为 Vue.js 生态打造,完美兼容 Vue 2 和 Vue 3 框架。并且能很好的兼容vue前端所有框架。 1.1插件功能提供了一个可视化的流程设计器,你可以在…

2025年刀轮船订制厂家权威推荐榜单:斗轮清淤船/刀轮式挖泥船/小型斗轮船源头厂家精选

在内河航道维护与水利工程建设领域,刀轮船作为高效清淤装备,其作业效率直接影响工程进度与成本。据水利行业统计数据显示,2025年我国内河清淤市场规模预计达到287亿元,年增长率稳定在8%-12%。 刀轮船凭借其独特的斗…

高效地使用std::map

#include <iostream> #include <string> #include <map> using namespace std;typedef map<string, int> M; M m; const char K[] = "key";void fn1 () {auto p = m.insert({K, 0…

flask:得到get/post参数

一,得到get参数 代码: from flask import Blueprint,jsonify,render_template,requestuser = Blueprint(user, __name__)# 用蓝图注册路由 @user.route("/add/") def user_add():# 得到get参数name = requ…

YACS2025年10月甲组

YACS2025年10月甲组T1. 数据结构 注意到可以离线,考虑整体二分。每次执行前一半操作,如果发现超过了 \(y\),那么答案就在前一半操作,否则就在后一半操作(如果补一个操作编号为 \(0\),整体加极大值的操作)。 所以…

2025年peek什么材料定制厂家权威推荐榜单:peek原料/材料peek/peek塑料原料源头厂家精选

在机器人轻量化与新能源汽车爆发式增长的浪潮下,PEEK(聚醚醚酮)材料凭借其卓越性能正成为高端制造领域的“新宠”。据行业数据显示,特斯拉Optimus单机使用PEEK量超过2kg,预计2025年全球出货量达50万台时,将激发出…

一对一视频聊天源码,高效查找方法之二分查找 - 云豹科技

一对一视频聊天源码,高效查找方法之二分查找介绍二分查找也称折半查找(Binary Search),它是一种效率较高的查找方法。但是,折半查找要求线性表必须采用顺序存储结构,而且表中元素按关键字有序排列。过程首先,假…

2025年高解析喷码机生产厂家权威推荐榜单:打标机/打码机/工业喷码机源头厂家精选

在“一物一码”成为食品、医药、线缆、日化等行业出厂标配的2025年,高解析喷码机已成为产品追溯、品牌防伪及生产管理不可或缺的一环。 高解析喷码技术正随着工业4.0深化与"中国智造"转型而持续进步。据行业…

Netty 示例

1. Netty 示例 1.1. 简单的 Echo 服务器 这里,我们直接使用Netty作为独立的进程启动 1.1.1. Netty 依赖 maven依赖如下: <dependency><groupId>io.netty</groupId><artifactId>netty-all<…

2025年电子压力试验生产厂家权威推荐榜单:混凝土压力试验机/纸箱压力试验机/全自动压力试验机源头厂家精选

在工程质量检测与材料研发领域,电子压力试验机作为衡量材料力学性能的关键设备,其测量精度与稳定性直接影响检测结果的可靠性。据行业报告显示,全球压力试验机市场正稳步增长,技术创新与智能化成为推动行业发展的核…

从网络下载图片到本地

/// <summary> /// 保存图片从web /// </summary> /// <param name="imgUrl">图片网页链接</param> /// <param name="path">保存路径</param> /// <para…