精通 triton 使用 MLIR 的源码逻辑 - 第001节:triton 的应用简介

        项目使用到 MLIR,通过了解 triton 对 MLIR 的使用,体会到 MLIR 在较大项目中的使用方式,汇总一下。

1. Triton 概述

        OpenAI Triton 是一个开源的编程语言和编译器,旨在简化 GPU 高性能计算(HPC) 的开发,特别是针对深度学习、科学计算等需要高效并行计算的领域。
既允许开发者编写高度优化的代码,又不必过度关注底层硬件细节。这样,通过简化高性能计算,可以加速新算法的实现和实验。传统 GPU 编程(如 CUDA)需要深入理解硬件架构和复杂的优化技术,而 Triton 旨在提供更高层次的抽象,降低开发门槛,但是设计 triton 语言及其编译器本身,门槛却非常高。


Triton 是基于 Python 的 DSL(领域特定语言),Triton 提供类似 Python 的语法,允许用户用简洁的代码表达并行计算逻辑,然后通过编译器优化为高效的 GPU 代码。其中,这些优化是自动化的。自动处理线程调度、内存合并(memory coalescing)、共享内存分配等底层优化,减少手动调优的工作量。Triton 在模块化与可扩展性方面下了不少功夫,它支持用户自定义内核(kernels)和优化策略,同时提供标准化的高性能算子库(如矩阵乘法、卷积等)。同时,Triton 可与 PyTorch 等深度学习框架集成,支持直接调用 Triton 内核。


在理念上,Triton 使用多级并行计算模型,借鉴 CUDA 的线程层次(thread blocks/grids),但通过更高层次的抽象(如 triton.program_id)简化编程。针对数据的局部性做优化,自动利用 GPU 的共享内存(shared memory)和寄存器,优化内存访问模式。Triton 把 LLVM 编译框架融合了进来,Triton 编译器将高级代码转换为优化的 PTX(NVIDIA GPU 的中间表示),同时结合了机器学习驱动的自动调优(auto-tuning)。在其前端,Triton 借助形式化程序语义,通过静态分析和程序变换确保代码的正确性和性能可预测性。

2. 基于预编译的包安装 triton

triton 通常跟 pytorch 一起使用;

2.1 安装 pytorch

安装一个基于 cuda 12.8 的 pytorch:

$ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

需要下载 几个 GB 的包,网络好的话会比较快,或者下班前、睡觉前安装;

验证安装:

2.2 安装triton

pip install triton

验证安装: 跑一个 tutorial 01:

$ wget https://triton-lang.org/main/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip
$ unzip ........$ python ./01-vector-add.py

运行结果应该如下:

3.  通过 example 了解 triton

3.1 01-vector-add.py 的源码

"""
Vector Addition
===============In this tutorial, you will write a simple vector addition using Triton.In doing so, you will learn about:* The basic programming model of Triton.* The `triton.jit` decorator, which is used to define Triton kernels.* The best practices for validating and benchmarking your custom ops against native reference implementations."""# %%
# Compute Kernel
# --------------import torchimport triton
import triton.language as tlDEVICE = triton.runtime.driver.active.get_active_torch_device()@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.y_ptr,  # *Pointer* to second input vector.output_ptr,  # *Pointer* to output vector.n_elements,  # Size of the vector.BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.# NOTE: `constexpr` so it can be used as a shape value.):# There are multiple 'programs' processing different data. We identify which program# we are here:pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.# This program will process inputs that are offset from the initial data.# For instance, if you had a vector of length 256 and block_size of 64, the programs# would each access the elements [0:64, 64:128, 128:192, 192:256].# Note that offsets is a list of pointers:block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)# Create a mask to guard memory operations against out-of-bounds accesses.mask = offsets < n_elements# Load x and y from DRAM, masking out any extra elements in case the input is not a# multiple of the block size.x = tl.load(x_ptr + offsets, mask=mask)y = tl.load(y_ptr + offsets, mask=mask)output = x + y# Write x + y back to DRAM.tl.store(output_ptr + offsets, output, mask=mask)# %%
# Let's also declare a helper function to (1) allocate the `z` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes:def add(x: torch.Tensor, y: torch.Tensor):# We need to preallocate the output.output = torch.empty_like(x)assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICEn_elements = output.numel()# The SPMD launch grid denotes the number of kernel instances that run in parallel.# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].# In this case, we use a 1D grid where the size is the number of blocks:grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )# NOTE:#  - Each torch.tensor object is implicitly converted into a pointer to its first element.#  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.#  - Don't forget to pass meta-parameters as keywords arguments.add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still# running asynchronously at this point.return output# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is 'f'{torch.max(torch.abs(output_torch - output_triton))}')# %%
# Seems like we're good to go!# %%
# Benchmark
# ---------
#
# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops.
# for different problem sizes.@triton.testing.perf_report(triton.testing.Benchmark(x_names=['size'],  # Argument names to use as an x-axis for the plot.x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_name`.x_log=True,  # x axis is logarithmic.line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.line_vals=['triton', 'torch'],  # Possible values for `line_arg`.line_names=['Triton', 'Torch'],  # Label name for the lines.styles=[('blue', '-'), ('green', '-')],  # Line styles.ylabel='GB/s',  # Label name for the y-axis.plot_name='vector-add-performance',  # Name for the plot. Used also as a file name for saving the plot.args={},  # Values for function arguments not in `x_names` and `y_name`.))
def benchmark(size, provider):x = torch.rand(size, device=DEVICE, dtype=torch.float32)y = torch.rand(size, device=DEVICE, dtype=torch.float32)quantiles = [0.5, 0.2, 0.8]if provider == 'torch':ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)if provider == 'triton':ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)return gbps(ms), gbps(max_ms), gbps(min_ms)# %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/' to save them to disk along with raw CSV data:
benchmark.run(print_data=True, show_plots=True)

3.2 01-vector-add.py 源码分析

    业务逻辑从 Line: 86 开始:torch.manual_seed(0)

首先,设置随机函数的种子;

接着,定义了两个一维的 tensor 变量 x 和 y,并随机了其元素的值;

然后,使用 pytorch 的 + 算符计算了两个 tensor 的逐元素和: output_torch = x + y;

接下来,调用自定义 add 函数,使用 triton kernel 计算了两个 tensor 的逐元素和。

从 add 函数开始逐行注释一下:

@triton.jit
def add_kernel(x_ptr,y_ptr,output_ptr,n_elements,BLOCK_SIZE: tl.constexpr,):pid = tl.program_id(axis=0)# 相当于 cuda 中 blockId.x,axis=0 是指 x方向block_start = pid * BLOCK_SIZE#当前block 在获取数据时的起始偏移offsets = block_start + tl.arange(0, BLOCK_SIZE)#本 block 覆盖的偏移范围mask = offsets < n_elements#offsets 的范围中,其值小于 n_el... 的话,mask 为true,否则为faulsex = tl.load(x_ptr + offsets, mask=mask)# mask 为true的话,取值y = tl.load(y_ptr + offsets, mask=mask)output = x + y#相加tl.store(output_ptr + offsets, output, mask=mask)#mask 为 true的话,存回 DRAMdef add(x: torch.Tensor, y: torch.Tensor):output = torch.empty_like(x)# 定义一个shape 跟x一样的tensor 变量。# 接下来检查 x,y,output 躺在的设备是否相同。assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE# 获取 output 这个 tensor 的元素个数,存在 n_elements 中。n_elements = output.numel()# 接下来两行代码将在正文中做一些解释:grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)return output

grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

逐条说明这句的要件:
这是一个动态计算网格大小的 lambda 函数
meta 参数是一个字典,包含内核的编译时常量(这里是 BLOCK_SIZE)
triton.cdiv 是 Triton 提供的向上取整除法函数,确保所有元素都被处理
grid 计算结果是一个元组,表示网格的维度(这里是1D网格)

lambda meta 的设计目标:
允许内核在不同块大小下复用,无需硬编码网格大小
使内核更加灵活,可以自动适应不同输入大小

add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
工作方式:
[grid] 部分指定了网格计算函数
Triton 运行时会首先调用 grid({'BLOCK_SIZE': 1024}) 获取实际网格大小,然后启动相应数量的线程块。

然后到了 triton kernel 的函数头:

@triton.jit
def add_kernel(x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):

tl.constexpr 的作用:
标记 BLOCK_SIZE 为编译时常量,在编译时而非运行时确定值
允许 Triton 编译器根据编译时常量进行优化(如循环展开)

函数体就不展开了,结合cuda 的编程方式,可以体会到很强的映射关系。

4. Triton 的 lambda meta 处理过程

        Triton 的 lambda meta 语法不是原生 Python 语法,而是一种由 Triton 编译器专门设计的领域特定语言(DSL)扩展。其工作原理大致分为语法解析阶段、编译处理阶段、代码生成阶段:

4.1. 语法解析阶段

当 Triton 遇到 kernel[grid](args) 这种语法时:

step1:  装饰器拦截

        @triton.jit 装饰器将 Python 函数标记为 Triton 内核
触发 Triton 的定制化解析流程

step2:  AST 转换

        Triton 使用 Python 的抽象语法树(AST)解析器获取代码结构
对 AST 进行转换,将特殊语法节点转换为 Triton 内部表示

step3:  Lambda Meta 处理

        识别 grid = lambda meta: ... 这种特殊模式
提取 lambda 函数体用于后续的网格计算

4.2. 编译时处理机制

网格计算 Lambda 的特殊处理

step1:  元参数字典构建

meta = {
'BLOCK_SIZE': 1024,  # 从内核调用传入
# 其他可能的编译时常量...
}

step2:  符号化执行

Triton 编译器对 lambda 体进行符号化分析
将 meta['BLOCK_SIZE'] 替换为实际值(如1024)
计算 triton.cdiv(n_elements, BLOCK_SIZE)

step3:  延迟执行设计

不像普通 Python lambda 立即执行,Triton 在编译时捕获 lambda 表达式,在代码生成阶段才实际计算网格大小

4.3. 代码生成阶段

 step1:  网格维度确定

调用 grid(meta) 获取具体网格形状,生成对应的 CUDA 网格启动配置

step2:  内核参数绑定

将 Python 参数(x,y,output)绑定到设备指针,并处理 tl.constexpr 参数的特殊传递

step3:  PTX 生成

最终生成类似如下的设备代码结构:

define void @add_kernel(..., i32 %n_elements) {%pid = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()%block_start = mul i32 %pid, 1024  // BLOCK_SIZE内联...
}

然后可以基于llvm内部后端模块生成PTX

5. triton lambda meta 与 python lambda 的对比

特性Python LambdaTriton Lambda Meta
执行时机运行时立即执行编译时延迟执行
参数类型常规 Python 对象特殊 meta 字典
可用操作完整 Python 语法受限的 Triton DSL 子集
优化方式无特别优化常量传播、循环展开等优化
返回值使用直接使用返回值用于配置内核启动参数


6. 设计原理深度解析

        这种元编程范式,允许在编译时基于参数动态生成代码,以便实现"一次编写,多配置生成"的效果。

其中用到了编译时常量传播

# 用户代码
grid = lambda meta: (triton.cdiv(n, meta['SIZE']),)

实际效果相当于

grid_size = (n + 1023) // 1024  # 当SIZE=1024时

如上所述,对其解析涉及到多阶段编译:

阶段1:解析Python AST,识别Triton特殊结构
阶段2:处理lambda meta,确定并行参数
阶段3:生成优化后的设备代码

        这种类型系统集成,其中,tl.constexpr 类型提示帮助编译器区分运行时变量(如n_elements)、编译时常量(如BLOCK_SIZE)

7. 使用常数特性实现性能优化


一些常用的 GPU 编程优化技巧,基于 meta 参数的常数性质,得到了实施。

        基于 BLOCK_SIZE 的编译时已知性可以至少完成如下三种常用优化:

(1.) 支持完全展开内存加载/存储等循环体
(2.) 支持寄存器分配(若非已知,则需要使用数组的方式,在 global mem 或shared mem上分配空间)

offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 可能被优化为寄存器数组而非内存操作


(3.) 用于边界检查的省略

当 n_elements % BLOCK_SIZE == 0 时

可以省略不必要的 mask 计算和相关分支检查代码的生成,自动进行性能优化

        这种设计最终帮助 Triton 在保持 Python 前端简洁性的同时,能够生成与手工优化 CUDA 代码相媲美的高性能GPU代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89816.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89816.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python爬虫-政务网站自动采集数据框架

前言 本文是该专栏的第81篇,后面会持续分享python爬虫干货知识,记得关注。 本文,笔者将详细介绍一个基于政务网站进行自动采集数据的爬虫框架。对此感兴趣的同学,千万别错过。 废话不多说,具体细节部分以及详细思路逻辑,跟着笔者直接往下看正文部分。(附带框架完整代码…

GitHub 趋势日报 (2025年07月19日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图1054shadPS4695n8n361remote-jobs321maigret257github-mcp-server249open_deep_res…

2025开源组件安全工具推荐OpenSCA

OpenSCA是国内最早的开源SCA平台&#xff0c;继承了商业级SCA的开源应用安全缺陷检测、多级开源依赖挖掘、纵深代码同源检测等核心能力&#xff0c;通过软件成分分析、依赖分析、特征分析、引用识别、合规分析等方法&#xff0c;深度挖掘组件中潜藏的各类安全漏洞及开源协议风险…

旅游管理实训基地建设:筑牢文旅人才培养的实践基石

随着文旅产业的蓬勃发展&#xff0c;行业对高素质、强实践的旅游管理人才需求日益迫切。旅游管理实训基地建设作为连接理论教学与行业实践的关键纽带&#xff0c;既是深化产教融合的重要载体&#xff0c;也是提升旅游管理专业人才培养质量的核心抓手。一、旅游管理实训基地建设…

网络爬虫的相关知识和操作

介绍 爬虫的定义 爬虫&#xff08;Web Crawler&#xff09;是一种自动化程序&#xff0c;用于从互联网上抓取、提取和存储网页数据。其核心功能是模拟人类浏览行为&#xff0c;访问目标网站并解析页面内容&#xff0c;最终将结构化数据保存到本地或数据库。 爬虫的工作原理 …

【vue-6】Vue3 响应式数据声明:深入理解 ref()

在 Vue3 的 Composition API 中&#xff0c;ref() 是最基础也是最常用的响应式数据声明方式之一。它为开发者提供了一种简单而强大的方式来管理组件状态。本文将深入探讨 ref() 的工作原理、使用场景以及最佳实践。 1. 什么是 ref()&#xff1f; ref() 是 Vue3 提供的一个函数&…

HTML常用标签汇总(精简版)

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>简单标记</title> </head><body>&…

【.net core】支持通过属性名称索引的泛型包装类

类/// <summary> /// 支持通过属性名称索引的泛型包装类 /// </summary> public class PropertyIndexer<T> : IEnumerable<T> {private T[] _items;private T _instance;private PropertyInfo[] _properties;private bool _caseSensitive;public Prope…

【机器学习|学习笔记】详解支持向量机(Support Vector Machine,SVM)为何要引入核函数?为何对缺失数据敏感?

【机器学习|学习笔记】详解支持向量机(Support Vector Machine,SVM)为何要引入核函数?为何对缺失数据敏感? 【机器学习|学习笔记】详解支持向量机(Support Vector Machine,SVM)为何要引入核函数?为何对缺失数据敏感? 文章目录 【机器学习|学习笔记】详解支持向量机(…

Bicep入门篇

前言 Azure Bicep 是 ARM 模板的最新版本,旨在解决开发人员在将资源部署到 Azure 时遇到的一些问题。它是一款开源工具,实际上是一种领域特定语言 (DSL),它提供了一种声明式编写基础架构的方法,该基础架构描述了虚拟机、Web 应用和网络接口等云资源的拓扑结构。它还鼓励在…

命名实体识别15年研究全景:从规则到机器学习的演进(1991-2006)

本文精读NRC Canada与NYU联合发表的经典综述《A survey of named entity recognition and classification》&#xff0c;解析NERC技术演进脉络与核心方法论 一、为什么命名实体识别&#xff08;NER&#xff09;如此重要&#xff1f; 命名实体识别&#xff08;Named Entity Rec…

eNSP综合实验(DNCP、NAT、TELET、HTTP、DNS)

1搭建实验拓扑2实验目的学习掌握eNSP中的命令3实验步骤3.1配置连接PC和客户端的交换机(仅以右侧为例)[Huawei]vlan batch 10 20 #创建vlan Info: This operation may take a few seconds. Please wait for a moment...done. [Huawei]un in en [Huawei]interface e0/0/2 [Huawei…

无人系统与安防监控中的超低延迟直播技术应用:基于大牛直播SDK的实战分享

技术背景 在 无人机、机器人 以及 智能安防 等高要求行业&#xff0c;高清视频的超低延迟传输 正在成为影响系统性能与业务决策的重要因素。无论是工业生产线的远程巡检、突发事件的应急响应&#xff0c;还是高风险环境下的智能监控与远程控制&#xff0c;视频链路的传输延迟都…

go语言学习之包

概念&#xff1a;在Go 语言中&#xff0c;包由一个或多个保存在同一目录的源码文件组成&#xff0c;包名宇目录名无关&#xff0c;但是通常大家习惯包名和目录名保持一致&#xff0c;同一目录的源码文件必须使用相同的包名。包的用途类似于其他语言的命名空间&#xff0c;可以限…

pytorch学习笔记(五)-- 计算机视觉的迁移学习

系列文章目录 pytorch学习笔记&#xff08;一&#xff09;-- pytorch深度学习框架基本知识了解 pytorch学习笔记&#xff08;二&#xff09;-- pytorch模型开发步骤详解 pytorch学习笔记&#xff08;三&#xff09;-- TensorBoard的介绍 pytorch学习笔记&#xff08;四&…

数字IC后端培训教程之数字后端项目典型项目案例解析

数字IC后端低功耗设计实现案例分享(3个power domain&#xff0c;2个voltage domain) Q1: 电路如下图&#xff0c;clk是一个很慢的时钟test_clk&#xff08;属于DFT的)&#xff0c;DFF1与and 形成一个clock gating check。跑pr 发现&#xff0c;时钟树综合CTS阶段&#xff08;C…

2025 Data Whale x PyTorch 安装学习笔记(Windows 版)

一、Anaconda 的安装与基本操作 1. 安装 Anaconda/miniconda 官方链接&#xff1a;Anaconda | Individual Edition 根据系统版本选择合适的安装包下载并安装。 2. 检验安装 打开 “开始” 菜单&#xff0c;找到 “Anaconda Prompt”&#xff08;一般在 Anaconda3 文件夹…

mac OS上docker安装zookeeper

拉取镜像&#xff1a;$ docker pull zookeeper:3.5.7 3.5.7: Pulling from library/zookeeper 3.5.7: Pulling from library/zookeeper 3.5.7: Pulling from library/zookeeper no matching manifest for linux/arm64/v8 in the manifest list entries报错&#xff1a;由于时M3…

设备通过4G网卡接入EasyCVR视频融合平台,出现无法播放的问题排查和解决

EasyCVR视频融合平台作为支持多协议接入、多设备集中管理的综合性视频解决方案&#xff0c;可实现各类终端设备的视频流汇聚与实时播放。近期收到用户反馈&#xff0c;在EasyCVR平台接入设备后出现视频流无法播放的情况。为帮助更多用户快速排查同类问题&#xff0c;现将具体处…

板凳-------Mysql cookbook学习 (十二--------3)

第二章 抽象数据类型和python类 2.5类定义实例&#xff1a; 学校人事管理系统中的类 import datetimeclass PersonValueError(ValueError):"""自定义异常类"""passclass PersonTypeError(TypeError):"""自定义异常类""…