CuTe C++ 简介02,gemm_device cuda kernel 的实现

  

        《CuTe C++ 简介01,从示例开始 》 中,最后看到了 计算 gemm 的cuda kernel,使用 NVIDIA CUTLASS 的 CUTe (CUDA Tile) 库实现的高性能 GEMM (通用矩阵乘法) CUDA kernel。接下来解释一下这个内核的各个部分。文末再贴一遍代码,方便查看。

1. 模板参数和函数签名

template <class ProblemShape, class CtaTiler,class TA, class AStride, class ASmemLayout, class AThreadLayout,class TB, class BStride, class BSmemLayout, class BThreadLayout,class TC, class CStride, class CSmemLayout, class CThreadLayout,class Alpha, class Beta>

这个内核高度模板化,支持配置,

        任意数据类型 (TA, TB, TC);

        任意矩阵形状和步长;

        任意内存布局和线程映射;

        不同的标量类型 (Alpha, Beta);

2. 静态断言和预条件检查

  大量的 CUTE_STATIC_ASSERT_V 和 static_assert 确保在编译时验证如下内容,

        正确的张量维度;

        线程布局和数据分块的兼容性;

        内存布局的一致性;

3. 张量创建和分块

Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA);
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});

        创建全局内存中的矩阵张量;

        使用 local_tile 将大矩阵分块为线程 block 处理的子块;

4. 共享内存分配

__shared__ TA smemA[cosize_v<ASmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);

        为矩阵 A 和 B 分配共享内存;

        使用模板化的布局确保内存访问的高效性;

5. 数据分区

Tensor tAgA = local_partition(gA, tA, threadIdx.x);
Tensor tAsA = local_partition(sA, tA, threadIdx.x);

        将全局内存和共享内存的数据分区给各个线程;

        每个线程负责加载特定的数据块;

6. 累加器分配和初始化

Tensor tCrC = make_tensor_like(tCgC);
clear(tCrC);

        为每个线程创建寄存器中的累加器;

        初始化为零;

7. 主循环 (核心计算部分)

for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{copy(tAgA(_,_,k_tile), tAsA);copy(tBgB(_,_,k_tile), tBsB);cp_async_fence();cp_async_wait<0>();__syncthreads();gemm(tCsA, tCsB, tCrC);__syncthreads();
}

这是内核的核心部分:

  1. 数据加载: 将全局内存中的数据拷贝到共享内存;

  2. 异步等待: 使用 cp_async 指令实现异步数据加载;

  3. 同步: 确保所有线程完成数据加载;

  4. 计算: 从共享内存加载数据并进行矩阵乘法计算;

  5. 再次同步: 确保所有线程完成计算;

8. 收尾处理

axpby(alpha, tCrC, beta, tCgC);

        应用缩放因子 alpha 和 beta;

        将计算结果写回全局内存;

关键点复盘

  1. 双缓冲机制: 通过循环处理 K 维度,实现计算和数据加载的重叠;

  2. 高效内存访问: 使用共享内存减少全局内存访问;

  3. 线程级并行: 精细的线程调度和数据分区;

  4. 模板元编程: 编译时优化,生成高度特化的代码;

  5. 异步拷贝: 使用 cp_async 指令隐藏内存延迟;

        评价:这个内核展示了现代 GPU 编程的最佳实践,通过精细的内存层次管理和线程调度来实现高性能的矩阵乘法运算。

template <class ProblemShape, class CtaTiler,class TA, class AStride, class ASmemLayout, class AThreadLayout,class TB, class BStride, class BSmemLayout, class BThreadLayout,class TC, class CStride, class CSmemLayout, class CThreadLayout,class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,Alpha alpha, Beta beta)
{using namespace cute;// PreconditionsCUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)static_assert(is_static<AThreadLayout>::value);static_assert(is_static<BThreadLayout>::value);static_assert(is_static<CThreadLayout>::value);CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreadsCUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreadsCUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_MCUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_KCUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_NCUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_KCUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_MCUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_Nstatic_assert(is_static<ASmemLayout>::value);static_assert(is_static<BSmemLayout>::value);static_assert(is_static<CSmemLayout>::value);CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_MCUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_MCUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_NCUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_NCUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_KCUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_KCUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MKCUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NKCUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN//// Full and Tiled Tensors//// Represent the full tensorsTensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)// Get the appropriate blocks for this thread blockauto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)// Shared memory buffers__shared__ TA smemA[cosize_v<ASmemLayout>];__shared__ TB smemB[cosize_v<BSmemLayout>];Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)//// Partition the copying of A and B tiles across the threads//// TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tilesTensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_MCUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_KCUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_NCUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K//// Define A/B partitioning and C accumulators//// TUTORIAL: Example of partitioning via projections of a ThreadLayout tC// Partition sA (BLK_M, BLK_K) by the rows of tCTensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)// Partition sB (BLK_N, BLK_K) by the cols of tCTensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)// Partition gC (M,N) by the tile of tCTensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)// Allocate the accumulators -- same shape/layout as the partitioned dataTensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_MCUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_MCUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_NCUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_NCUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K// Clear the accumulatorsclear(tCrC);#if 0if(thread0()) {print("  mA : "); print(  mA); print("\n");print("  gA : "); print(  gA); print("\n");print("  sA : "); print(  sA); print("\n");print("tAgA : "); print(tAgA); print("\n");print("tAsA : "); print(tAsA); print("\n");}
#endif#if 0if(thread0()) {print("  mB : "); print(  mB); print("\n");print("  gB : "); print(  gB); print("\n");print("  sB : "); print(  sB); print("\n");print("tBgB : "); print(tBgB); print("\n");print("tBsB : "); print(tBsB); print("\n");}
#endif#if 0if(thread0()) {print("  mC : "); print(  mC); print("\n");print("  gC : "); print(  gC); print("\n");print("tCsA : "); print(tCsA); print("\n");print("tCsB : "); print(tCsB); print("\n");print("tCgC : "); print(tCgC); print("\n");print("tCrC : "); print(tCrC); print("\n");}
#endif#if 1// TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,//           and then computes on those tiles.//   copy(.) operates on the global and shared memory via the tA|tB partitioning//   gemm(.) operates on the shared and register memory via the tC partitioningauto K_TILE_MAX = size<2>(tAgA);for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile){// Copy gmem to smem with tA|tB thread-partitioned tensorscopy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)// TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to//   Tensor tAgAk = tAgA(_,_,k_tile);//   CUTE_UNROLL//   for (int i = 0; i < size(tAsA); ++i) {//     tAsA(i) = tAgAk(i);//   }cp_async_fence();        // Label the end of (potential) cp.async instructionscp_async_wait<0>();      // Sync on all (potential) cp.async instructions__syncthreads();         // Wait for all threads to write to smem// Compute gemm on tC thread-partitioned smemgemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)// TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to//   CUTE_UNROLL//   for (int k = 0; k < size<1>(tCsA); ++k) {//     CUTE_UNROLL//     for (int m = 0; m < size<0>(tCrC); ++m) {//       CUTE_UNROLL//       for (int n = 0; n < size<1>(tCrC); ++n) {//         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);//       }//     }//   }__syncthreads();         // Wait for all threads to read from smem}#endif//// Epilogue//axpby(alpha, tCrC, beta, tCgC);// TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to//   CUTE_UNROLL//   for (int i = 0; i < size(tCrC); ++i) {//     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);//   }
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/95922.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/95922.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

万代《宝可梦》主题新品扭蛋公开!史上最大尺寸

使用jQuery的常用方法与返回值分析 jQuery是一个轻量级的JavaScript库&#xff0c;旨在简化HTML文档遍历和操作、事件处理以及动画效果的创建。本文将介绍一些常用的jQuery方法及其返回值&#xff0c;帮助开发者更好地理解和运用这一强大的库。 1. 选择器方法 jQuery提供了多种…

【FastDDS】Layer Transport ( 05-Shared Memory Transport)

6.4 共享内存传输 共享内存&#xff08;SHM&#xff09;传输依靠主机操作系统提供的共享内存机制&#xff0c;实现了在同一处理单元/机器上运行的实体之间的快速通信。注意 Fast DDS 利用域参与者&#xff08;DomainParticipant&#xff09;的 GuidPrefix_t 来识别在同一主机上…

记 2025/9/6

人工智能常见的模型按照处理问题分为6大类&#xff1a;处理权重问题的权重模型、处理状态问题的状态模型、处理序列问题的问题模型、处理表示问题的表示模型、处理相似度的相似模型、处理分类问题的分类模型。权重是计算特定状态下事物的重要性。状态问题是刻画权重动态变化的过…

开启Python之路,第一节学习大纲-从入门到进阶

前端开启Python之路&#xff0c;前端有没有必要卷后端技术&#xff0c;欢迎各位大神批评指正 第一阶段&#xff1a;基础入门 (打好根基) 目标&#xff1a; 理解编程基本概念&#xff0c;掌握 Python 核心语法&#xff0c;能编写简单的脚本程序。 1、环境搭建与开发工具 安装 Py…

webshell及冰蝎双击无法打开?

什么是webshell&#xff1f; web:万维网 shell&#xff1a;是指一种应用程序&#xff0c;为用户和系统之间建立连接&#xff0c;通过这个界面访问操作系统内核的服务 webshell:是以asp、aspx、php、jsp或者cgi等网页文件形式存在的一种命令执行环境&#xff0c;也可以将其称做…

【星闪】Hi2821 | PWM脉宽调制模块 + 呼吸灯例程

1. 简介PWM&#xff08;Pulse Width Modulation&#xff09;&#xff0c;全称脉宽调制&#xff0c;通过对一系列脉冲的宽度进行调制&#xff0c;等效出所需波形。即对模拟信号电平进行数字编码&#xff0c;通过调节频率、占空比的变化来调节信号的变化。一个 PWM 周期内由一段高…

51单片机---硬件学习(电子琴、主从应答模式、modbus模型、DS18B20传感器显示温度)

一、串行通信与并行通信1、串行通信定义&#xff1a;数据一位一位地按顺序通过单条传输线进行传输的通信方式。优点&#xff1a;传输线少&#xff0c;成本低&#xff0c;适合长距离传输缺点&#xff1a;传输速度相对较慢2、并行通信定义&#xff1a;数据的各位同时通过多条并行…

SpringBoot后端开发常用工具详细介绍——SpringSecurity认证用户保证安全

简单的开始 创建SpringBoot项目 首先创建一个简单的springboot项目&#xff0c;假设端口为8888&#xff0c;添加controller控制层&#xff0c;并在其中添加TestController控制类&#xff0c;那么启动springboot项目之后&#xff0c;访localhost:8888/api/message页面会显示my…

别再手工缝合API了!开源LLMOps神器LMForge,让你像搭积木一样玩转AI智能体!

你是否受够了这些&#xff1f; 刚调通OpenAI的API&#xff0c;老板说“咱们试试国产模型降本增效”&#xff0c;你看着满屏的if-else只想说“我晕”。想给AI加上“查天气”、“执行代码”的能力&#xff0c;却发现Function Calling的代码复杂得让人头皮发麻。本地的Agentdemo惊…

window使用ffmep工具,加自定义脚本执行视频转码成h264(运营人员使用)

技术文章大纲&#xff1a;ffmep配合脚本使用1. 需要提供脚本给视频转码的给运营,给运营上传视频使用安装ffmep windows版本(目前我使用的就是windows)将脚本里面的执行路径修改成自己的电脑安装ffmep/bin/ffmep.exe路径处理好之后就点击执行2.环境准备ffmep windows版解压到一个…

Leetcode 240. 搜索二维矩阵 II 矩阵 / 二分

原题链接&#xff1a; Leetcode 240. 搜索二维矩阵 II 解法一&#xff1a;排除法 参考 【图解】排除法&#xff0c;一图秒懂&#xff01;&#xff08;Python/Java/C/C/Go/JS/Rust&#xff09; 从右上角&#xff1a; class Solution { public:bool searchMatrix(vector<vec…

OCR 证件识别:驱动澳门酒店自助入住智能化

澳门酒店作为国际旅游窗口&#xff0c;每日接待持多元证件的旅客&#xff0c;OCR 证件识别技术的应用&#xff0c;让自助入住终端实现 “一证通办”&#xff0c;大幅提升服务效率。​旅客在自助终端办理入住时&#xff0c;只需将护照、港澳通行证、回乡证、电子身份证等证件贴近…

深入解析汇编语言的奥秘

汇编语言简介汇编语言&#xff08;Assembly Language&#xff09;是一种低级编程语言&#xff0c;直接对应计算机的机器指令集。它通过助记符&#xff08;如 MOV、ADD&#xff09;代替二进制操作码&#xff0c;更接近硬件架构&#xff0c;常用于性能优化、嵌入式开发或逆向工程…

Nextcloud 实战:打造属于你的私有云与在线协作平台

随着数据安全与隐私保护意识的提升&#xff0c;越来越多的个人和组织选择自建云平台来替代公有云。Nextcloud 作为一款开源的文件同步与协作套件&#xff0c;不仅能实现类似网盘的文件存储与分享&#xff0c;还提供日历、联系人、即时通讯、在线文档编辑等协作功能&#xff0c;…

实践指南:利用衡石AI Data Agent实现自然语言驱动的指标开发与归因

在数字化转型的深水区&#xff0c;企业数据团队常面临两难困境&#xff1a;业务部门需要敏捷响应的指标分析&#xff0c;但传统BI工具依赖技术团队编写SQL&#xff0c;导致需求交付周期长达数周&#xff1b;而直接暴露底层数据又存在安全与合规风险。衡石科技推出的AI Data Age…

知微集:Python中的线程(三)

欢迎来到"一起学点什么吧"的合集「NLP知微集」。在这里&#xff0c;我们不愿宏大叙事&#xff0c;只聚焦于自然语言处理领域中那些细微却关键的“齿轮”与“螺丝钉”。我相信&#xff0c;真正深刻的理解&#xff0c;源于对细节的洞察。本期&#xff0c;我将为您拆解的…

动态规划入门:从记忆化搜索到动态规划

在开始对动态规划的讲解之前&#xff0c;我们需要先对记忆化搜索进行回顾&#xff1a; 什么是记忆化搜索&#xff1f; 在搜索过程中&#xff0c;当搜索树中存在大量重复的节点时&#xff0c;我们可以通过引入一个"备忘录"&#xff08;通常是一个数组或哈希表&#…

Boost搜索引擎 网络库与前端(4)

文章目录前言一、引入网络库模块引入cpp-httplibcpp-httplib测试正式编写http_server二、前端模块三、项目的可能拓展总结前言 终于到了最后一篇喽&#xff0c;嘻嘻&#xff01; 一、引入网络库模块 引入cpp-httplib 下载地址如下&#xff0c;我个人不喜欢新版本   cpp-http…

Flink反压问题

背景在使用flink的过程中&#xff0c;多次遇到过反压&#xff08;backpressure&#xff09;的问题&#xff0c;这通常是因为数据处理的速率超过了数据源或下游系统的处理能力导致。反压的底层剖析网络流控一个重要的概念是网络流控&#xff0c;如上图&#xff0c;不同的Consume…

Day5-中间件与请求处理

昨天搞定了异步优化&#xff0c;今天来解决一些实际问题。Day4的API虽然性能不错&#xff0c;但还缺少一些企业级应用必备的功能。 现在的问题 前端无法访问API&#xff08;跨域问题&#xff09;没有请求日志&#xff0c;出问题难以排查错误信息格式不统一缺少统一的请求处理机…