项目使用到 MLIR,通过了解 triton 对 MLIR 的使用,体会到 MLIR 在较大项目中的使用方式,汇总一下。
1. Triton 概述
OpenAI Triton 是一个开源的编程语言和编译器,旨在简化 GPU 高性能计算(HPC) 的开发,特别是针对深度学习、科学计算等需要高效并行计算的领域。
既允许开发者编写高度优化的代码,又不必过度关注底层硬件细节。这样,通过简化高性能计算,可以加速新算法的实现和实验。传统 GPU 编程(如 CUDA)需要深入理解硬件架构和复杂的优化技术,而 Triton 旨在提供更高层次的抽象,降低开发门槛,但是设计 triton 语言及其编译器本身,门槛却非常高。
Triton 是基于 Python 的 DSL(领域特定语言),Triton 提供类似 Python 的语法,允许用户用简洁的代码表达并行计算逻辑,然后通过编译器优化为高效的 GPU 代码。其中,这些优化是自动化的。自动处理线程调度、内存合并(memory coalescing)、共享内存分配等底层优化,减少手动调优的工作量。Triton 在模块化与可扩展性方面下了不少功夫,它支持用户自定义内核(kernels)和优化策略,同时提供标准化的高性能算子库(如矩阵乘法、卷积等)。同时,Triton 可与 PyTorch 等深度学习框架集成,支持直接调用 Triton 内核。
在理念上,Triton 使用多级并行计算模型,借鉴 CUDA 的线程层次(thread blocks/grids),但通过更高层次的抽象(如 triton.program_id)简化编程。针对数据的局部性做优化,自动利用 GPU 的共享内存(shared memory)和寄存器,优化内存访问模式。Triton 把 LLVM 编译框架融合了进来,Triton 编译器将高级代码转换为优化的 PTX(NVIDIA GPU 的中间表示),同时结合了机器学习驱动的自动调优(auto-tuning)。在其前端,Triton 借助形式化程序语义,通过静态分析和程序变换确保代码的正确性和性能可预测性。
2. 基于预编译的包安装 triton
triton 通常跟 pytorch 一起使用;
2.1 安装 pytorch
安装一个基于 cuda 12.8 的 pytorch:
$ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
需要下载 几个 GB 的包,网络好的话会比较快,或者下班前、睡觉前安装;
验证安装:
2.2 安装triton
pip install triton
验证安装: 跑一个 tutorial 01:
$ wget https://triton-lang.org/main/_downloads/763344228ae6bc253ed1a6cf586aa30d/tutorials_python.zip
$ unzip ........$ python ./01-vector-add.py
运行结果应该如下:
3. 通过 example 了解 triton
3.1 01-vector-add.py 的源码
"""
Vector Addition
===============In this tutorial, you will write a simple vector addition using Triton.In doing so, you will learn about:* The basic programming model of Triton.* The `triton.jit` decorator, which is used to define Triton kernels.* The best practices for validating and benchmarking your custom ops against native reference implementations."""# %%
# Compute Kernel
# --------------import torchimport triton
import triton.language as tlDEVICE = triton.runtime.driver.active.get_active_torch_device()@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.y_ptr, # *Pointer* to second input vector.output_ptr, # *Pointer* to output vector.n_elements, # Size of the vector.BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.# NOTE: `constexpr` so it can be used as a shape value.):# There are multiple 'programs' processing different data. We identify which program# we are here:pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.# This program will process inputs that are offset from the initial data.# For instance, if you had a vector of length 256 and block_size of 64, the programs# would each access the elements [0:64, 64:128, 128:192, 192:256].# Note that offsets is a list of pointers:block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)# Create a mask to guard memory operations against out-of-bounds accesses.mask = offsets < n_elements# Load x and y from DRAM, masking out any extra elements in case the input is not a# multiple of the block size.x = tl.load(x_ptr + offsets, mask=mask)y = tl.load(y_ptr + offsets, mask=mask)output = x + y# Write x + y back to DRAM.tl.store(output_ptr + offsets, output, mask=mask)# %%
# Let's also declare a helper function to (1) allocate the `z` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes:def add(x: torch.Tensor, y: torch.Tensor):# We need to preallocate the output.output = torch.empty_like(x)assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICEn_elements = output.numel()# The SPMD launch grid denotes the number of kernel instances that run in parallel.# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].# In this case, we use a 1D grid where the size is the number of blocks:grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )# NOTE:# - Each torch.tensor object is implicitly converted into a pointer to its first element.# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.# - Don't forget to pass meta-parameters as keywords arguments.add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still# running asynchronously at this point.return output# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is 'f'{torch.max(torch.abs(output_torch - output_triton))}')# %%
# Seems like we're good to go!# %%
# Benchmark
# ---------
#
# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops.
# for different problem sizes.@triton.testing.perf_report(triton.testing.Benchmark(x_names=['size'], # Argument names to use as an x-axis for the plot.x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`.x_log=True, # x axis is logarithmic.line_arg='provider', # Argument name whose value corresponds to a different line in the plot.line_vals=['triton', 'torch'], # Possible values for `line_arg`.line_names=['Triton', 'Torch'], # Label name for the lines.styles=[('blue', '-'), ('green', '-')], # Line styles.ylabel='GB/s', # Label name for the y-axis.plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot.args={}, # Values for function arguments not in `x_names` and `y_name`.))
def benchmark(size, provider):x = torch.rand(size, device=DEVICE, dtype=torch.float32)y = torch.rand(size, device=DEVICE, dtype=torch.float32)quantiles = [0.5, 0.2, 0.8]if provider == 'torch':ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)if provider == 'triton':ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)return gbps(ms), gbps(max_ms), gbps(min_ms)# %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/' to save them to disk along with raw CSV data:
benchmark.run(print_data=True, show_plots=True)
3.2 01-vector-add.py 源码分析
业务逻辑从 Line: 86 开始:torch.manual_seed(0)
首先,设置随机函数的种子;
接着,定义了两个一维的 tensor 变量 x 和 y,并随机了其元素的值;
然后,使用 pytorch 的 + 算符计算了两个 tensor 的逐元素和: output_torch = x + y;
接下来,调用自定义 add 函数,使用 triton kernel 计算了两个 tensor 的逐元素和。
从 add 函数开始逐行注释一下:
@triton.jit
def add_kernel(x_ptr,y_ptr,output_ptr,n_elements,BLOCK_SIZE: tl.constexpr,):pid = tl.program_id(axis=0)# 相当于 cuda 中 blockId.x,axis=0 是指 x方向block_start = pid * BLOCK_SIZE#当前block 在获取数据时的起始偏移offsets = block_start + tl.arange(0, BLOCK_SIZE)#本 block 覆盖的偏移范围mask = offsets < n_elements#offsets 的范围中,其值小于 n_el... 的话,mask 为true,否则为faulsex = tl.load(x_ptr + offsets, mask=mask)# mask 为true的话,取值y = tl.load(y_ptr + offsets, mask=mask)output = x + y#相加tl.store(output_ptr + offsets, output, mask=mask)#mask 为 true的话,存回 DRAMdef add(x: torch.Tensor, y: torch.Tensor):output = torch.empty_like(x)# 定义一个shape 跟x一样的tensor 变量。# 接下来检查 x,y,output 躺在的设备是否相同。assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE# 获取 output 这个 tensor 的元素个数,存在 n_elements 中。n_elements = output.numel()# 接下来两行代码将在正文中做一些解释:grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)return output
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
逐条说明这句的要件:
这是一个动态计算网格大小的 lambda 函数
meta 参数是一个字典,包含内核的编译时常量(这里是 BLOCK_SIZE)
triton.cdiv 是 Triton 提供的向上取整除法函数,确保所有元素都被处理
grid 计算结果是一个元组,表示网格的维度(这里是1D网格)
lambda meta 的设计目标:
允许内核在不同块大小下复用,无需硬编码网格大小
使内核更加灵活,可以自动适应不同输入大小
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
工作方式:
[grid] 部分指定了网格计算函数
Triton 运行时会首先调用 grid({'BLOCK_SIZE': 1024}) 获取实际网格大小,然后启动相应数量的线程块。
然后到了 triton kernel 的函数头:
@triton.jit
def add_kernel(x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
tl.constexpr 的作用:
标记 BLOCK_SIZE 为编译时常量,在编译时而非运行时确定值
允许 Triton 编译器根据编译时常量进行优化(如循环展开)
函数体就不展开了,结合cuda 的编程方式,可以体会到很强的映射关系。
4. Triton 的 lambda meta 处理过程
Triton 的 lambda meta 语法不是原生 Python 语法,而是一种由 Triton 编译器专门设计的领域特定语言(DSL)扩展。其工作原理大致分为语法解析阶段、编译处理阶段、代码生成阶段:
4.1. 语法解析阶段
当 Triton 遇到 kernel[grid](args) 这种语法时:
step1: 装饰器拦截
@triton.jit 装饰器将 Python 函数标记为 Triton 内核
触发 Triton 的定制化解析流程
step2: AST 转换
Triton 使用 Python 的抽象语法树(AST)解析器获取代码结构
对 AST 进行转换,将特殊语法节点转换为 Triton 内部表示
step3: Lambda Meta 处理
识别 grid = lambda meta: ... 这种特殊模式
提取 lambda 函数体用于后续的网格计算
4.2. 编译时处理机制
网格计算 Lambda 的特殊处理
step1: 元参数字典构建
meta = {
'BLOCK_SIZE': 1024, # 从内核调用传入
# 其他可能的编译时常量...
}
step2: 符号化执行
Triton 编译器对 lambda 体进行符号化分析
将 meta['BLOCK_SIZE'] 替换为实际值(如1024)
计算 triton.cdiv(n_elements, BLOCK_SIZE)
step3: 延迟执行设计
不像普通 Python lambda 立即执行,Triton 在编译时捕获 lambda 表达式,在代码生成阶段才实际计算网格大小
4.3. 代码生成阶段
step1: 网格维度确定
调用 grid(meta) 获取具体网格形状,生成对应的 CUDA 网格启动配置
step2: 内核参数绑定
将 Python 参数(x,y,output)绑定到设备指针,并处理 tl.constexpr 参数的特殊传递
step3: PTX 生成
最终生成类似如下的设备代码结构:
define void @add_kernel(..., i32 %n_elements) {%pid = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()%block_start = mul i32 %pid, 1024 // BLOCK_SIZE内联...
}
然后可以基于llvm内部后端模块生成PTX
5. triton lambda meta 与 python lambda 的对比
特性 | Python Lambda | Triton Lambda Meta |
执行时机 | 运行时立即执行 | 编译时延迟执行 |
参数类型 | 常规 Python 对象 | 特殊 meta 字典 |
可用操作 | 完整 Python 语法 | 受限的 Triton DSL 子集 |
优化方式 | 无特别优化 | 常量传播、循环展开等优化 |
返回值使用 | 直接使用返回值 | 用于配置内核启动参数 |
6. 设计原理深度解析
这种元编程范式,允许在编译时基于参数动态生成代码,以便实现"一次编写,多配置生成"的效果。
其中用到了编译时常量传播
# 用户代码
grid = lambda meta: (triton.cdiv(n, meta['SIZE']),)
实际效果相当于
grid_size = (n + 1023) // 1024 # 当SIZE=1024时
如上所述,对其解析涉及到多阶段编译:
阶段1:解析Python AST,识别Triton特殊结构
阶段2:处理lambda meta,确定并行参数
阶段3:生成优化后的设备代码
这种类型系统集成,其中,tl.constexpr 类型提示帮助编译器区分运行时变量(如n_elements)、编译时常量(如BLOCK_SIZE)
7. 使用常数特性实现性能优化
一些常用的 GPU 编程优化技巧,基于 meta 参数的常数性质,得到了实施。
基于 BLOCK_SIZE 的编译时已知性可以至少完成如下三种常用优化:
(1.) 支持完全展开内存加载/存储等循环体
(2.) 支持寄存器分配(若非已知,则需要使用数组的方式,在 global mem 或shared mem上分配空间)
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 可能被优化为寄存器数组而非内存操作
(3.) 用于边界检查的省略
当 n_elements % BLOCK_SIZE == 0 时
可以省略不必要的 mask 计算和相关分支检查代码的生成,自动进行性能优化
这种设计最终帮助 Triton 在保持 Python 前端简洁性的同时,能够生成与手工优化 CUDA 代码相媲美的高性能GPU代码。